Skip to content

Commit

Permalink
fix: try to support mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed May 15, 2024
1 parent 6fb9d29 commit 727b9ab
Show file tree
Hide file tree
Showing 27 changed files with 353 additions and 234 deletions.
5 changes: 5 additions & 0 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"recommendations": [
"ms-python.mypy-type-checker"
]
}
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"mypy-type-checker.importStrategy": "fromEnvironment",
"mypy-type-checker.args": [
"--config-file=pyproject.toml"
]
}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@ stubPath = "./typings"
[tool.mypy]
python_version = "3.8"
strict = true
disable_error_code = ["import-untyped", "overload-overlap", "override"]
disallow_subclassing_any = false
7 changes: 5 additions & 2 deletions src/joblib-stubs/_cloudpickle_wrapper.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import typing_extensions
from joblib.externals.loky import (
wrap_non_picklable_objects as _wrap_non_picklable_objects,
)

def _my_wrap_non_picklable_objects[T](obj: T, keep_wrapper: bool = ...) -> T: ...
_T = typing_extensions.TypeVar("_T")

def _my_wrap_non_picklable_objects(obj: _T, keep_wrapper: bool = ...) -> _T: ...

if bool(): # noqa: PYI002, UP018
wrap_non_picklable_objects = _my_wrap_non_picklable_objects
else:
wrap_non_picklable_objects = _wrap_non_picklable_objects
wrap_non_picklable_objects = _wrap_non_picklable_objects # type: ignore[assignment]

__all__ = ["wrap_non_picklable_objects"]
44 changes: 28 additions & 16 deletions src/joblib-stubs/_dask.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,25 @@ import weakref
from concurrent import futures
from multiprocessing.pool import AsyncResult as AsyncResult

import typing_extensions
from dask.distributed import Client as Client
from dask.distributed import Future as Future
from distributed.deploy.cluster import Cluster as Cluster
from joblib.parallel import _R
from joblib.parallel import AutoBatchingMixin as AutoBatchingMixin
from joblib.parallel import Parallel as Parallel
from joblib.parallel import ParallelBackendBase as ParallelBackendBase
from joblib.parallel import _ReturnAs as _ReturnAs
from joblib.parallel import parallel_config as parallel_config
from tornado.ioloop import IOLoop as IOLoop

_T = typing_extensions.TypeVar("_T")
_P = typing_extensions.ParamSpec("_P")
_R = typing_extensions.TypeVar(
"_R",
default=typing.Literal["list"],
bound="_ReturnAs", # noqa: PYI020
)

def is_weakrefable(obj: typing.Any) -> bool: ...

class _WeakKeyDictionary:
Expand All @@ -22,21 +31,22 @@ class _WeakKeyDictionary:
def __len__(self) -> int: ...
def clear(self) -> None: ...

type _TaskItem[**P, T] = tuple[
typing.Callable[P, T], list[typing.Any], dict[str, typing.Any]
_TaskItem: typing_extensions.TypeAlias = tuple[
typing.Callable[_P, _T], list[typing.Any], dict[str, typing.Any]
]

class Batch[T]:
def __init__(self, tasks: list[_TaskItem[..., T]]) -> None: ...
def __call__(self, tasks: list[_TaskItem[..., T]] | None = ...) -> list[T]: ...
class Batch(typing.Generic[_T]):
def __init__(self, tasks: list[_TaskItem[..., _T]]) -> None: ...
def __call__(self, tasks: list[_TaskItem[..., _T]] | None = ...) -> list[_T]: ...

type _ScatterIterItem = list[typing.Any] | dict[typing.Any, typing.Any]
_ScatterIterItem: typing_extensions.TypeAlias = (
list[typing.Any] | dict[typing.Any, typing.Any]
)

class DaskDistributedBackend(
AutoBatchingMixin[_R], ParallelBackendBase[_R], typing.Generic[_R]
):
MIN_IDEAL_BATCH_DURATION: typing.ClassVar[float]
MAX_IDEAL_BATCH_DURATION: typing.ClassVar[float]
client: Client
data_futures: dict[int, Future]
wait_for_workers_timeout: float
Expand All @@ -52,19 +62,21 @@ class DaskDistributedBackend(
**submit_kwargs: typing.Any,
) -> None: ...
def __reduce__(self) -> tuple[type[DaskDistributedBackend], tuple[()]]: ...
def get_nested_backend(
def get_nested_backend( # type: ignore[override]
self,
) -> tuple[DaskDistributedBackend, typing.Literal[-1]]: ...
parallel: Parallel
def configure(
parallel: Parallel[_R]
def configure( # type: ignore[override]
self,
n_jobs: int = ...,
parallel: Parallel | None = ...,
parallel: Parallel[_R] | None = ...,
**backend_args: typing.Any,
) -> int: ...
call_data_futures: weakref.WeakKeyDictionary[typing.Any, typing.Any]
def apply_async[T](
def apply_async(
self,
func: typing.Callable[[], T],
callback: typing.Callable[[T], typing.Any] | None = ...,
) -> futures.Future[T]: ...
func: typing.Callable[[], _T],
callback: typing.Callable[[_T], typing.Any] | None = ...,
) -> futures.Future[_T]: ...
# mypy
def effective_n_jobs(self, n_jobs: int) -> int: ...
2 changes: 1 addition & 1 deletion src/joblib-stubs/_memmapping_reducer.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from joblib.numpy_pickle import load as load
from joblib.numpy_pickle import load_temporary_memmap as load_temporary_memmap
from numpy.typing import ArrayLike, NDArray

type _MmapMode = typing.Literal["r+", "r", "w+", "c"]
_MmapMode: typing_extensions.TypeAlias = typing.Literal["r+", "r", "w+", "c"]

WindowsError: type[OSError | None]
SYSTEM_SHARED_MEM_FS: str
Expand Down
59 changes: 36 additions & 23 deletions src/joblib-stubs/_parallel_backends.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,30 @@ from collections.abc import Generator
from concurrent import futures
from multiprocessing.pool import AsyncResult as AsyncResult

import typing_extensions
from joblib._multiprocessing_helpers import mp as mp
from joblib.executor import get_memmapping_executor as get_memmapping_executor
from joblib.externals.loky import cpu_count as cpu_count
from joblib.externals.loky import process_executor as process_executor
from joblib.externals.loky.process_executor import (
ShutdownExecutorError as ShutdownExecutorError,
)
from joblib.parallel import _R
from joblib.parallel import Parallel as Parallel
from joblib.parallel import _ReturnAs as _ReturnAs
from joblib.pool import MemmappingPool as MemmappingPool

type _Prefer = typing.Literal["processes", "threads"]
type _Require = typing.Literal["sharedmem"]
_T = typing_extensions.TypeVar("_T")
_T_co = typing_extensions.TypeVar("_T_co", covariant=True)
_R = typing_extensions.TypeVar(
"_R",
default=typing.Literal["list"],
bound="_ReturnAs", # noqa: PYI020
)

class _AnyContainer(typing.Protocol[_T_co]): ... # mypy override error

_Prefer: typing_extensions.TypeAlias = typing.Literal["processes", "threads"]
_Require: typing_extensions.TypeAlias = typing.Literal["sharedmem"]

class ParallelBackendBase(typing.Generic[_R], metaclass=ABCMeta):
supports_inner_max_num_threads: typing.ClassVar[bool]
Expand All @@ -39,17 +50,15 @@ class ParallelBackendBase(typing.Generic[_R], metaclass=ABCMeta):
@abstractmethod
def effective_n_jobs(self, n_jobs: int) -> int: ...
@abstractmethod
def apply_async[T](
def apply_async(
self,
func: typing.Callable[[], T],
callback: typing.Callable[[futures.Future[T] | AsyncResult[T]], typing.Any]
| typing.Callable[[futures.Future[T]], typing.Any]
| typing.Callable[[AsyncResult[T]], typing.Any]
func: typing.Callable[[], _T],
callback: typing.Callable[[_AnyContainer[_T]], typing.Any] # FIXME: mypy error
| None = ...,
) -> futures.Future[T] | AsyncResult[T]: ...
def retrieve_result_callback[T](
self, out: futures.Future[T] | AsyncResult[T]
) -> T: ...
) -> _AnyContainer[_T]: ...
def retrieve_result_callback(
self, out: futures.Future[_T] | AsyncResult[_T]
) -> _T: ...
parallel: Parallel[_R]
def configure(
self,
Expand Down Expand Up @@ -85,15 +94,17 @@ class SequentialBackend(ParallelBackendBase[_R], typing.Generic[_R]):
func: typing.Callable[[], typing.Any],
callback: typing.Callable[..., typing.Any] | None = ...,
) -> typing.NoReturn: ...
# mypy
def effective_n_jobs(self, n_jobs: int) -> int: ...

class PoolManagerMixin:
def effective_n_jobs(self, n_jobs: int) -> int: ...
def terminate(self) -> None: ...
def apply_async[T](
def apply_async(
self,
func: typing.Callable[[], T],
callback: typing.Callable[[AsyncResult[T]], typing.Any] | None = ...,
) -> AsyncResult[T]: ...
func: typing.Callable[[], _T],
callback: typing.Callable[[AsyncResult[_T]], typing.Any] | None = ...,
) -> AsyncResult[_T]: ...
def retrieve_result_callback(self, out: typing.Any) -> typing.Any: ...
def abort_everything(self, ensure_ready: bool = ...) -> None: ...

Expand All @@ -110,14 +121,14 @@ class ThreadingBackend(PoolManagerMixin, ParallelBackendBase[_R], typing.Generic
supports_retrieve_callback: typing.ClassVar[bool]
uses_threads: typing.ClassVar[bool]
supports_sharedmem: typing.ClassVar[bool]
def configure(
def configure( # type: ignore[override]
self,
n_jobs: int = ...,
parallel: Parallel[_R] | None = ...,
**backend_args: typing.Any,
) -> int: ...

class MultiprocessingBackend(
class MultiprocessingBackend( # type: ignore[misc] # FIXME
PoolManagerMixin, AutoBatchingMixin[_R], ParallelBackendBase[_R], typing.Generic[_R]
):
supports_retrieve_callback: typing.ClassVar[bool]
Expand All @@ -131,7 +142,7 @@ class MultiprocessingBackend(
**memmappingpool_args: typing.Any,
) -> int: ...

class LokyBackend(AutoBatchingMixin[_R], ParallelBackendBase[_R], typing.Generic[_R]):
class LokyBackend(AutoBatchingMixin[_R], ParallelBackendBase[_R], typing.Generic[_R]): # type: ignore[misc] # FIXME
supports_retrieve_callback: typing.ClassVar[bool]
supports_inner_max_num_threads: typing.ClassVar[bool]
def configure(
Expand All @@ -145,11 +156,13 @@ class LokyBackend(AutoBatchingMixin[_R], ParallelBackendBase[_R], typing.Generic
) -> int: ...
def terminate(self) -> None: ...
def abort_everything(self, ensure_ready: bool = ...) -> None: ...
def apply_async[T](
def apply_async(
self,
func: typing.Callable[[], T],
callback: typing.Callable[[futures.Future[T]], typing.Any] | None = ...,
) -> futures.Future[T]: ...
func: typing.Callable[[], _T],
callback: typing.Callable[[futures.Future[_T]], typing.Any] | None = ...,
) -> futures.Future[_T]: ...
# mypy
def effective_n_jobs(self, n_jobs: int) -> int: ...

class FallbackToBackend(Exception): # noqa: N818
backend: ParallelBackendBase[typing.Any]
Expand Down
19 changes: 17 additions & 2 deletions src/joblib-stubs/_store_backends.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import typing
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta

import typing_extensions
from joblib import numpy_pickle as numpy_pickle
from joblib._memmapping_reducer import _MmapMode
from joblib.backports import concurrency_safe_rename as concurrency_safe_rename
Expand All @@ -10,6 +11,8 @@ from joblib.disk import mkdirp as mkdirp
from joblib.disk import rm_subdirs as rm_subdirs
from joblib.logger import format_time as format_time

_T = typing_extensions.TypeVar("_T")

class _ItemInfo(typing.TypedDict, total=True):
location: str

Expand All @@ -20,8 +23,10 @@ class CacheItemInfo(typing.NamedTuple):

class CacheWarning(Warning): ...

def concurrency_safe_write[T](
object_to_write: T, filename: str, write_func: typing.Callable[[T, str], typing.Any]
def concurrency_safe_write(
object_to_write: _T,
filename: str,
write_func: typing.Callable[[_T, str], typing.Any],
) -> str: ...

class StoreBackendBase(metaclass=ABCMeta):
Expand Down Expand Up @@ -77,3 +82,13 @@ class FileSystemStoreBackend(StoreBackendBase, StoreBackendMixin):
compress: bool
mmap_mode: _MmapMode
verbose: int
# mypy
def create_location(self, location: str) -> None: ...
def clear_location(self, location: str) -> None: ...
def get_items(self) -> list[CacheItemInfo]: ...
def configure(
self,
location: str,
verbose: int = ...,
backend_options: dict[str, typing.Any] | None = ...,
) -> None: ...
18 changes: 11 additions & 7 deletions src/joblib-stubs/_utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@ import ast
import typing
from dataclasses import dataclass

import typing_extensions
from joblib._multiprocessing_helpers import mp as mp

_T = typing_extensions.TypeVar("_T")
_P = typing_extensions.ParamSpec("_P")

operators: dict[ast.AST, typing.Callable[[typing.Any, typing.Any], typing.Any]]

def eval_expr(expr: str) -> typing.Any: ...
def eval_(node: ast.AST) -> typing.Any: ...
@dataclass(frozen=True)
class _Sentinel[T]:
default_value: T
def __init__(self, default_value: T) -> None: ...
class _Sentinel(typing.Generic[_T]):
default_value: _T
def __init__(self, default_value: _T) -> None: ...

class _TracebackCapturingWrapper[**P, T]:
func: typing.Callable[P, T]
def __init__(self, func: typing.Callable[P, T]) -> None: ...
def __call__(self, **kwargs: typing.Any) -> T: ...
class _TracebackCapturingWrapper(typing.Generic[_P, _T]):
func: typing.Callable[_P, _T]
def __init__(self, func: typing.Callable[_P, _T]) -> None: ...
def __call__(self, **kwargs: typing.Any) -> _T: ...
22 changes: 13 additions & 9 deletions src/joblib-stubs/compressor.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ import lzma
import typing
from io import BufferedIOBase as BufferedIOBase

import typing_extensions
from _typeshed import ReadableBuffer, WriteableBuffer
from joblib.backports import LooseVersion as LooseVersion
from lz4.frame import LZ4FrameFile # type: ignore

_BufferedIOBaseT = typing_extensions.TypeVar("_BufferedIOBaseT", bound=BufferedIOBase)

LZ4_NOT_INSTALLED_ERROR: str
_COMPRESSORS: dict[str, CompressorWrapper]
_COMPRESSORS: dict[str, CompressorWrapper[typing.Any]]
_ZFILE_PREFIX: bytes
_ZLIB_PREFIX: bytes
_GZIP_PREFIX: bytes
Expand All @@ -18,20 +22,20 @@ _LZMA_PREFIX: bytes
_LZ4_PREFIX: bytes

def register_compressor(
compressor_name: str, compressor: CompressorWrapper, force: bool = ...
compressor_name: str, compressor: CompressorWrapper[typing.Any], force: bool = ...
) -> None: ...

class CompressorWrapper[T: BufferedIOBase]:
fileobj_factory: type[T]
class CompressorWrapper(typing.Generic[_BufferedIOBaseT]):
fileobj_factory: type[_BufferedIOBaseT]
prefix: bytes
extension: str
def __init__(
self, obj: typing.Any, prefix: bytes = ..., extension: str = ...
) -> None: ...
def compressor_file(
self, fileobj: typing.Any, compresslevel: int | None = ...
) -> T: ...
def decompressor_file(self, fileobj: typing.Any) -> T: ...
) -> _BufferedIOBaseT: ...
def decompressor_file(self, fileobj: typing.Any) -> _BufferedIOBaseT: ...

class BZ2CompressorWrapper(CompressorWrapper[bz2.BZ2File]): ...
class LZMACompressorWrapper(CompressorWrapper[lzma.LZMAFile]): ...
Expand Down Expand Up @@ -60,9 +64,9 @@ class BinaryZlibFile(io.BufferedIOBase):
def seekable(self) -> bool: ...
def readable(self) -> bool: ...
def writable(self) -> bool: ...
def read(self, size: int = ...) -> bytes | None: ...
def readinto(self, b: bytes) -> int: ...
def write(self, data: bytes) -> int: ...
def read(self, size: int = ...) -> bytes | None: ... # type: ignore[override]
def readinto(self, b: WriteableBuffer) -> int: ...
def write(self, data: ReadableBuffer) -> int: ...
def seek(self, offset: int, whence: int = ...) -> int: ...
def tell(self) -> int: ...

Expand Down
Loading

0 comments on commit 727b9ab

Please sign in to comment.