diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 0a593fbda..833dff91d 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: os: [ubuntu, macos, windows] - python-version: [3.9, "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] name: ${{ matrix.os }} - py${{ matrix.python-version }} runs-on: ${{ matrix.os }}-latest defaults: @@ -110,7 +110,7 @@ jobs: fail-fast: false matrix: mne-version: ["1.4.2", "1.5.0"] - python-version: [3.9] + python-version: ["3.10"] name: mne compat ${{ matrix.mne-version }} - py${{ matrix.python-version }} runs-on: ubuntu-latest defaults: diff --git a/README.md b/README.md index 96d0986cd..63b1554ef 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ transmitted wirelessly. For more information about LSL, please visit the # Install -MNE-LSL supports `python ≥ 3.9` and is available on +MNE-LSL supports `python ≥ 3.10` and is available on [PyPI](https://pypi.org/project/mne-lsl/) and on [conda-forge](https://anaconda.org/conda-forge/mne-lsl). Install instruction can be found on the diff --git a/doc/conf.py b/doc/conf.py index be98714ab..52f612e8d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -10,7 +10,6 @@ from datetime import date from importlib import import_module from pathlib import Path -from typing import Optional import mne from intersphinx_registry import get_intersphinx_mapping @@ -210,7 +209,7 @@ # https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html -def linkcode_resolve(domain: str, info: dict[str, str]) -> Optional[str]: +def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None: """Determine the URL corresponding to a Python object. Parameters diff --git a/doc/resources/install.rst b/doc/resources/install.rst index 3faddee0b..34c9e8576 100644 --- a/doc/resources/install.rst +++ b/doc/resources/install.rst @@ -6,7 +6,7 @@ Install Default install --------------- -``MNE-LSL`` requires Python version ``3.9`` or higher and is available on +``MNE-LSL`` requires Python version ``3.10`` or higher and is available on `PyPI `_ and `conda-forge `_. It requires `liblsl `_ which will be either fetch from the path in the environment variable ``MNE_LSL_LIB`` (or ``PYLSL_LIB``), or from the system directories, or diff --git a/examples/10_peak_detection.py b/examples/10_peak_detection.py index a64c17545..ceae40bc2 100644 --- a/examples/10_peak_detection.py +++ b/examples/10_peak_detection.py @@ -80,6 +80,7 @@ for raw_, label in zip( (raw, raw_notched, raw_bandpassed, raw_lowpassed), ("raw", "notched", "bandpassed", "lowpassed"), + strict=True, ): data, times = raw_[:, start:stop] # select 5 seconds data -= data.mean() # detrend @@ -97,6 +98,7 @@ for raw_, label in zip( (raw, raw_notched, raw_bandpassed, raw_lowpassed), ("raw", "notched", "bandpassed", "lowpassed"), + strict=True, ): data, times = raw_[:, start:stop] # select 5 seconds data -= data.mean() # detrend diff --git a/mne_lsl/datasets/_fetch.py b/mne_lsl/datasets/_fetch.py index 303b74d95..f55794498 100644 --- a/mne_lsl/datasets/_fetch.py +++ b/mne_lsl/datasets/_fetch.py @@ -10,10 +10,9 @@ if TYPE_CHECKING: from pathlib import Path - from typing import Union -def fetch_dataset(path: Path, base_url: str, registry: Union[str, Path]) -> Path: +def fetch_dataset(path: Path, base_url: str, registry: str | Path) -> Path: """Fetch a dataset from the remote. Parameters diff --git a/mne_lsl/datasets/sample.py b/mne_lsl/datasets/sample.py index a12f61790..edb4f690e 100644 --- a/mne_lsl/datasets/sample.py +++ b/mne_lsl/datasets/sample.py @@ -2,7 +2,6 @@ from importlib.resources import files from pathlib import Path -from typing import TYPE_CHECKING import pooch from mne.utils import get_config @@ -10,14 +9,11 @@ from ..utils._checks import ensure_path from ._fetch import fetch_dataset -if TYPE_CHECKING: - from typing import Optional, Union - _REGISTRY: Path = files("mne_lsl.datasets") / "sample-registry.txt" def _make_registry( - folder: Union[str, Path], output: Optional[Union[str, Path]] = None + folder: str | Path, output: str | Path | None = None ) -> None: # pragma: no cover """Create the registry file for the sample dataset. diff --git a/mne_lsl/datasets/testing.py b/mne_lsl/datasets/testing.py index f94aeb83d..06c3a4079 100644 --- a/mne_lsl/datasets/testing.py +++ b/mne_lsl/datasets/testing.py @@ -2,7 +2,6 @@ from importlib.resources import files from pathlib import Path -from typing import TYPE_CHECKING import pooch from mne.utils import get_config @@ -10,15 +9,10 @@ from ..utils._checks import ensure_path from ._fetch import fetch_dataset -if TYPE_CHECKING: - from typing import Optional, Union - _REGISTRY: Path = files("mne_lsl.datasets") / "testing-registry.txt" -def _make_registry( - folder: Union[str, Path], output: Optional[Union[str, Path]] = None -) -> None: +def _make_registry(folder: str | Path, output: str | Path | None = None) -> None: """Create the registry file for the testing dataset. Parameters diff --git a/mne_lsl/datasets/tests/test_fetch.py b/mne_lsl/datasets/tests/test_fetch.py index b594f2816..15a7391c6 100644 --- a/mne_lsl/datasets/tests/test_fetch.py +++ b/mne_lsl/datasets/tests/test_fetch.py @@ -1,19 +1,15 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING import pooch import pytest from mne_lsl.datasets._fetch import fetch_dataset -if TYPE_CHECKING: - from typing import Optional - @pytest.fixture -def license_file() -> Optional[Path]: +def license_file() -> Path | None: """Find the license file if present.""" fname = Path(__file__).parent.parent.parent.parent / "LICENSE" if fname.exists(): diff --git a/mne_lsl/lsl/_utils.py b/mne_lsl/lsl/_utils.py index 35f338ced..3336a97d1 100644 --- a/mne_lsl/lsl/_utils.py +++ b/mne_lsl/lsl/_utils.py @@ -1,5 +1,4 @@ from ctypes import POINTER, c_int, c_void_p, cast -from typing import Optional from .load_liblsl import lib @@ -193,7 +192,7 @@ def free_char_p_array_memory(char_p_array): # noqa: D103 # -- Static checker ----------------------------------------------------------- -def check_timeout(timeout: Optional[float]) -> float: +def check_timeout(timeout: float | None) -> float: """Check that the provided timeout is valid. Parameters diff --git a/mne_lsl/lsl/functions.py b/mne_lsl/lsl/functions.py index 5301d1c7f..baaedfa1f 100644 --- a/mne_lsl/lsl/functions.py +++ b/mne_lsl/lsl/functions.py @@ -1,15 +1,11 @@ from __future__ import annotations from ctypes import byref, c_char_p, c_double, c_void_p -from typing import TYPE_CHECKING from ..utils._checks import check_type, ensure_int from .load_liblsl import lib from .stream_info import _BaseStreamInfo -if TYPE_CHECKING: - from typing import Optional - def library_version() -> int: """Version of the binary LSL library. @@ -55,9 +51,9 @@ def local_clock() -> float: def resolve_streams( timeout: float = 1.0, - name: Optional[str] = None, - stype: Optional[str] = None, - source_id: Optional[str] = None, + name: str | None = None, + stype: str | None = None, + source_id: str | None = None, minimum: int = 1, ) -> list[_BaseStreamInfo]: """Resolve streams on the network. @@ -116,7 +112,7 @@ def resolve_streams( properties = [ # filter out the properties set to None (prop, name) - for prop, name in zip(properties, ("name", "stype", "source_id")) + for prop, name in zip(properties, ("name", "stype", "source_id"), strict=True) if prop is not None ] timeout /= len(properties) diff --git a/mne_lsl/lsl/load_liblsl.py b/mne_lsl/lsl/load_liblsl.py index 594871b37..dca9541c4 100644 --- a/mne_lsl/lsl/load_liblsl.py +++ b/mne_lsl/lsl/load_liblsl.py @@ -21,8 +21,6 @@ from ..utils.logs import logger, warn if TYPE_CHECKING: - from typing import Optional, Union - from pooch import Pooch @@ -71,7 +69,7 @@ def load_liblsl() -> CDLL: return _set_types(lib) -def _load_liblsl_environment_variables() -> Optional[str]: +def _load_liblsl_environment_variables() -> str | None: """Load the binary LSL library from the environment variables. Returns @@ -101,7 +99,7 @@ def _load_liblsl_environment_variables() -> Optional[str]: return None -def _load_liblsl_system() -> Optional[str]: +def _load_liblsl_system() -> str | None: """Load the binary LSL library from the system path/folders. Returns @@ -128,7 +126,7 @@ def _load_liblsl_system() -> Optional[str]: return None -def _load_liblsl_mne_lsl(*, folder: Path = _LIB_FOLDER) -> Optional[str]: +def _load_liblsl_mne_lsl(*, folder: Path = _LIB_FOLDER) -> str | None: """Load the binary LSL library from the system path/folders. Parameters @@ -173,7 +171,7 @@ def _load_liblsl_mne_lsl(*, folder: Path = _LIB_FOLDER) -> Optional[str]: def _fetch_liblsl( *, - folder: Union[str, Path] = _LIB_FOLDER, + folder: str | Path = _LIB_FOLDER, url: str = "https://api.github.com/repos/sccn/liblsl/releases/latest", ) -> str: """Fetch liblsl on the release page. @@ -420,8 +418,8 @@ def _is_valid_libpath(libpath: str) -> bool: def _attempt_load_liblsl( - libpath: Union[str, Path], *, issue_warning: bool = True -) -> tuple[str, Optional[int]]: + libpath: str | Path, *, issue_warning: bool = True +) -> tuple[str, int | None]: """Try loading a binary LSL library. Parameters diff --git a/mne_lsl/lsl/stream_info.py b/mne_lsl/lsl/stream_info.py index 2828b2277..41b7273a7 100644 --- a/mne_lsl/lsl/stream_info.py +++ b/mne_lsl/lsl/stream_info.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: - from typing import Any, Optional, Union + from typing import Any from numpy.typing import DTypeLike @@ -153,7 +153,7 @@ def __repr__(self) -> str: # -- Core information, assigned at construction ------------------------------------ @property - def dtype(self) -> Union[str, DTypeLike]: + def dtype(self) -> str | DTypeLike: """Channel format of a stream. All channels in a stream have the same format. @@ -385,7 +385,9 @@ def get_channel_info(self) -> Info: highpass = filters.child("highpass").first_child().value() lowpass = filters.child("lowpass").first_child().value() with info._unlock(): - for name, value in zip(("highpass", "lowpass"), (highpass, lowpass)): + for name, value in zip( + ("highpass", "lowpass"), (highpass, lowpass), strict=True + ): if len(value) != 0: try: info[name] = float(value) @@ -405,7 +407,7 @@ def get_channel_info(self) -> Info: info["dig"] = dig return info - def get_channel_names(self) -> Optional[list[str]]: + def get_channel_names(self) -> list[str] | None: """Get the channel names in the description. Returns @@ -423,7 +425,7 @@ def get_channel_names(self) -> Optional[list[str]]: """ return self._get_channel_info("ch_name") - def get_channel_types(self) -> Optional[list[str]]: + def get_channel_types(self) -> list[str] | None: """Get the channel types in the description. Returns @@ -441,7 +443,7 @@ def get_channel_types(self) -> Optional[list[str]]: """ return self._get_channel_info("ch_type") - def get_channel_units(self) -> Optional[list[str]]: + def get_channel_units(self) -> list[str] | None: """Get the channel units in the description. Returns @@ -459,7 +461,7 @@ def get_channel_units(self) -> Optional[list[str]]: """ return self._get_channel_info("ch_unit") - def _get_channel_info(self, name: str) -> Optional[list[str]]: + def _get_channel_info(self, name: str) -> list[str] | None: """Get the 'channel/name' element in the XML tree.""" if self.desc.child("channels").empty(): return None @@ -619,7 +621,11 @@ def set_channel_info(self, info: Info) -> None: loc = ch.child("loc") loc = ch.append_child("loc") if loc.empty() else loc _BaseStreamInfo._set_description_node( - loc, {key: value for key, value in zip(_LOC_NAMES, ch_info["loc"])} + loc, + { + key: value + for key, value in zip(_LOC_NAMES, ch_info["loc"], strict=True) + }, ) ch = ch.next_sibling() assert ch.empty() # sanity-check @@ -636,7 +642,7 @@ def set_channel_info(self, info: Info) -> None: if info["dig"] is not None: self._set_digitization(info["dig"]) - def set_channel_names(self, ch_names: Union[list[str], tuple[str]]) -> None: + def set_channel_names(self, ch_names: list[str] | tuple[str, ...]) -> None: """Set the channel names in the description. Existing labels are overwritten. Parameters @@ -646,7 +652,7 @@ def set_channel_names(self, ch_names: Union[list[str], tuple[str]]) -> None: """ self._set_channel_info(ch_names, "ch_name") - def set_channel_types(self, ch_types: Union[str, list[str]]) -> None: + def set_channel_types(self, ch_types: str | list[str] | tuple[str, ...]) -> None: """Set the channel types in the description. Existing types are overwritten. The types are given as human readable strings, e.g. ``'eeg'``. @@ -663,7 +669,14 @@ def set_channel_types(self, ch_types: Union[str, list[str]]) -> None: self._set_channel_info(ch_types, "ch_type") def set_channel_units( - self, ch_units: Union[str, list[str], int, list[int], ScalarIntArray] + self, + ch_units: str + | list[str] + | int + | list[int] + | ScalarIntArray + | tuple[str, ...] + | tuple[int, ...], ) -> None: """Set the channel units in the description. Existing units are overwritten. @@ -704,7 +717,9 @@ def set_channel_units( ] self._set_channel_info(ch_units, "ch_unit") - def _set_channel_info(self, ch_infos: list[str], name: str) -> None: + def _set_channel_info( + self, ch_infos: list[str] | tuple[str, ...], name: str + ) -> None: """Set the 'channel/name' element in the XML tree.""" check_type(ch_infos, (list, tuple), name) for ch_info in ch_infos: @@ -745,7 +760,7 @@ def _set_channel_projectors(self, projs: list[Projection]) -> None: data = projector.append_child("data") if data.empty() else data ch = data.child("channel") for ch_name, ch_data in zip( - proj["data"]["col_names"], np.squeeze(proj["data"]["data"]) + proj["data"]["col_names"], np.squeeze(proj["data"]["data"]), strict=True ): ch = data.append_child("channel") if ch.empty() else ch _BaseStreamInfo._set_description_node( @@ -773,7 +788,11 @@ def _set_digitization(self, dig_points: list[DigPoint]) -> None: if loc.empty(): loc = point.append_child("loc") _BaseStreamInfo._set_description_node( - loc, {key: value for key, value in zip(("X", "Y", "Z"), dig_point["r"])} + loc, + { + key: value + for key, value in zip(("X", "Y", "Z"), dig_point["r"], strict=True) + }, ) point = point.next_sibling() _BaseStreamInfo._prune_description_node(point, dig) @@ -811,10 +830,10 @@ def _set_description_node(node: XMLElement, mapping: dict[str, Any]) -> None: # -- Helper methods to retrieve FIFF elements in the XMLElement tree --------------- @staticmethod def _get_fiff_int_named( - value: Optional[str], + value: str | None, name: str, mapping: dict[int, int], - ) -> Optional[int]: + ) -> int | None: """Try to retrieve the FIFF integer code from the str representation.""" if value is None: return None @@ -908,7 +927,7 @@ def __init__( # ---------------------------------------------------------------------------------- @staticmethod - def _dtype2idxfmt(dtype: Union[str, int, DTypeLike]) -> int: + def _dtype2idxfmt(dtype: str | int | DTypeLike) -> int: """Convert a string format to its LSL integer value.""" if dtype in fmt2idx: return fmt2idx[dtype] diff --git a/mne_lsl/lsl/stream_inlet.py b/mne_lsl/lsl/stream_inlet.py index 6a97992f9..e16f8d63e 100644 --- a/mne_lsl/lsl/stream_inlet.py +++ b/mne_lsl/lsl/stream_inlet.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from typing import Optional, Union from numpy.typing import DTypeLike, NDArray @@ -63,7 +62,7 @@ def __init__( chunk_size: int = 0, max_buffered: float = 360, recover: bool = True, - processing_flags: Optional[Union[str, Sequence[str]]] = None, + processing_flags: str | Sequence[str] | None = None, ): check_type(sinfo, (_BaseStreamInfo,), "sinfo") chunk_size = ensure_int(chunk_size, "chunk_size") @@ -166,7 +165,7 @@ def __del__(self): self._del() # no-op if called more than once logger.debug(f"Deleting {self.__class__.__name__}.") - def open_stream(self, timeout: Optional[float] = None) -> None: + def open_stream(self, timeout: float | None = None) -> None: """Subscribe to a data stream. All samples pushed in at the other end from this moment onwards will be queued @@ -222,7 +221,7 @@ def close_stream(self) -> None: logger.debug("Closing stream, lib.lsl_close_stream(self._obj) done.") self._stream_is_open = False - def time_correction(self, timeout: Optional[float] = None) -> float: + def time_correction(self, timeout: float | None = None) -> float: """Retrieve an estimated time correction offset for the given stream. The first call to this function takes several milliseconds until a reliable @@ -253,8 +252,8 @@ def time_correction(self, timeout: Optional[float] = None) -> float: return result def pull_sample( - self, timeout: Optional[float] = 0.0 - ) -> tuple[Union[list[str], ScalarArray], Optional[float]]: + self, timeout: float | None = 0.0 + ) -> tuple[list[str] | ScalarArray, float | None]: """Pull a single sample from the inlet. Parameters @@ -309,9 +308,9 @@ def pull_sample( def pull_chunk( self, - timeout: Optional[float] = 0.0, + timeout: float | None = 0.0, max_samples: int = 1024, - ) -> tuple[Union[list[list[str]], ScalarArray], NDArray[np.float64]]: + ) -> tuple[list[list[str]] | ScalarArray, NDArray[np.float64]]: """Pull a chunk of samples from the inlet. Parameters @@ -433,7 +432,7 @@ def _obj(self, obj): # ---------------------------------------------------------------------------------- @copy_doc(_BaseStreamInfo.dtype) @property - def dtype(self) -> Union[str, DTypeLike]: + def dtype(self) -> str | DTypeLike: return fmt2numpy.get(self._dtype, "string") @copy_doc(_BaseStreamInfo.n_channels) @@ -475,7 +474,7 @@ def was_clock_reset(self) -> bool: return bool(lib.lsl_was_clock_reset(self._obj)) # ---------------------------------------------------------------------------------- - def get_sinfo(self, timeout: Optional[float] = None) -> _BaseStreamInfo: + def get_sinfo(self, timeout: float | None = None) -> _BaseStreamInfo: """:class:`~mne_lsl.lsl.StreamInfo` corresponding to this Inlet. Parameters diff --git a/mne_lsl/lsl/stream_outlet.py b/mne_lsl/lsl/stream_outlet.py index f09fd0b9e..5f0807b5a 100644 --- a/mne_lsl/lsl/stream_outlet.py +++ b/mne_lsl/lsl/stream_outlet.py @@ -15,8 +15,6 @@ from .stream_info import _BaseStreamInfo if TYPE_CHECKING: - from typing import Optional, Union - from numpy.typing import DTypeLike from .._typing import ScalarArray, ScalarFloatArray @@ -110,7 +108,7 @@ def __del__(self): def push_sample( self, - x: Union[list[str], ScalarArray], + x: list[str] | ScalarArray, timestamp: float = 0.0, pushThrough: bool = True, ) -> None: @@ -163,8 +161,8 @@ def push_sample( def push_chunk( self, - x: Union[list[list[str]], ScalarArray], - timestamp: Optional[Union[float, ScalarFloatArray]] = None, + x: list[list[str]] | ScalarArray, + timestamp: float | ScalarFloatArray | None = None, pushThrough: bool = True, ) -> None: """Push a chunk of samples into the :class:`~mne_lsl.lsl.StreamOutlet`. @@ -248,7 +246,7 @@ def push_chunk( ) timestamp_c = (c_double * timestamp.size)(*timestamp.astype(np.float64)) liblsl_push_chunk_func = self._do_push_chunk_n - elif isinstance(timestamp, (float, int)): + elif isinstance(timestamp, (float | int)): if self.sfreq == 0.0 and n_samples != 1 and timestamp != 0: warn( "The stream is irregularly sampled and timestamp is a float and " @@ -274,7 +272,7 @@ def push_chunk( ) ) - def wait_for_consumers(self, timeout: Optional[float]) -> bool: + def wait_for_consumers(self, timeout: float | None) -> bool: """Wait (block) until at least one :class:`~mne_lsl.lsl.StreamInlet` connects. Parameters @@ -311,7 +309,7 @@ def _obj(self, obj): @copy_doc(_BaseStreamInfo.dtype) @property - def dtype(self) -> Union[str, DTypeLike]: + def dtype(self) -> str | DTypeLike: return fmt2numpy.get(self._dtype, "string") @copy_doc(_BaseStreamInfo.n_channels) diff --git a/mne_lsl/lsl/tests/test_load_liblsl.py b/mne_lsl/lsl/tests/test_load_liblsl.py index 8d63ca176..3c60aa846 100644 --- a/mne_lsl/lsl/tests/test_load_liblsl.py +++ b/mne_lsl/lsl/tests/test_load_liblsl.py @@ -74,6 +74,10 @@ def liblsl_outdated(tmp_path, download_liblsl_outdated) -> Path: return tmp_path / download_liblsl_outdated.name +@pytest.mark.skipif( + _PLATFORM == "linux", + reason="Runner ubuntu-latest runs on 24.04 and LSL did not release yet for it.", +) @pytest.mark.skipif( _PLATFORM == "windows", reason="PermissionError: [WinError 5] Access is denied (on Path.unlink(...)).", diff --git a/mne_lsl/player/_base.py b/mne_lsl/player/_base.py index f7181efcd..4b17462e4 100644 --- a/mne_lsl/player/_base.py +++ b/mne_lsl/player/_base.py @@ -27,8 +27,9 @@ from ..utils.meas_info import _set_channel_units if TYPE_CHECKING: + from collections.abc import Callable from datetime import datetime - from typing import Any, Callable, Optional, Union + from typing import Any from mne import Info @@ -54,9 +55,9 @@ class BasePlayer(ABC, ContainsMixin, SetChannelsMixin): @abstractmethod def __init__( self, - fname: Union[str, Path, BaseRaw], + fname: str | Path | BaseRaw, chunk_size: int = 10, - n_repeat: Union[int, float] = np.inf, + n_repeat: int | float = np.inf, ) -> None: self._chunk_size = ensure_int(chunk_size, "chunk_size") if self._chunk_size <= 0: @@ -88,10 +89,10 @@ def __init__( @fill_doc def anonymize( self, - daysback: Optional[int] = None, + daysback: int | None = None, keep_his: bool = False, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BasePlayer: """Anonymize the measurement information in-place. @@ -157,10 +158,10 @@ def get_channel_units( @fill_doc def rename_channels( self, - mapping: Union[dict[str, str], Callable], + mapping: dict[str, str] | Callable, allow_duplicates: bool = False, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BasePlayer: """Rename channels. @@ -208,7 +209,7 @@ def set_channel_types( mapping: dict[str, str], *, on_unit_change: str = "warn", - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BasePlayer: """Define the sensor type of channels. @@ -244,7 +245,7 @@ def set_channel_types( return self @abstractmethod - def set_channel_units(self, mapping: dict[str, Union[str, int]]) -> BasePlayer: + def set_channel_units(self, mapping: dict[str, str | int]) -> BasePlayer: """Define the channel unit multiplication factor. By convention, MNE stores data in SI units. But systems often stream in non-SI @@ -292,7 +293,7 @@ def set_channel_units(self, mapping: dict[str, Union[str, int]]) -> BasePlayer: return self def set_meas_date( - self, meas_date: Optional[Union[datetime, float, tuple[float, float]]] + self, meas_date: datetime | float | tuple[float, float] | None ) -> BasePlayer: """Set the measurement start date. @@ -397,7 +398,7 @@ def chunk_size(self) -> int: return self._chunk_size @property - def fname(self) -> Optional[Path]: + def fname(self) -> Path | None: """Path to file played. :type: :class:`~pathlib.Path` | None @@ -413,7 +414,7 @@ def info(self) -> Info: return self._raw.info @property - def n_repeat(self) -> Optional[int]: + def n_repeat(self) -> int | None: """Number of times the file is repeated. :type: :class:`int` | ``np.inf`` diff --git a/mne_lsl/player/_base.pyi b/mne_lsl/player/_base.pyi index ec4ecd529..69a08c64c 100644 --- a/mne_lsl/player/_base.pyi +++ b/mne_lsl/player/_base.pyi @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +from collections.abc import Callable from datetime import datetime as datetime from pathlib import Path -from typing import Any, Callable +from typing import Any from _typeshed import Incomplete from mne import Info diff --git a/mne_lsl/player/player_lsl.py b/mne_lsl/player/player_lsl.py index 027509106..3cffebc0d 100644 --- a/mne_lsl/player/player_lsl.py +++ b/mne_lsl/player/player_lsl.py @@ -15,8 +15,8 @@ from ._base import BasePlayer if TYPE_CHECKING: + from collections.abc import Callable from pathlib import Path - from typing import Callable, Optional, Union @fill_doc @@ -92,13 +92,13 @@ class PlayerLSL(BasePlayer): def __init__( self, - fname: Union[str, Path], + fname: str | Path, chunk_size: int = 10, - n_repeat: Union[int, float] = np.inf, + n_repeat: int | float = np.inf, *, - name: Optional[str] = None, + name: str | None = None, source_id: str = "MNE-LSL", - annotations: Optional[bool] = None, + annotations: bool | None = None, ) -> None: super().__init__(fname, chunk_size, n_repeat) check_type(name, (str, None), "name") @@ -159,10 +159,10 @@ def __init__( @copy_doc(BasePlayer.rename_channels) def rename_channels( self, - mapping: Union[dict[str, str], Callable], + mapping: dict[str, str] | Callable, allow_duplicates: bool = False, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> PlayerLSL: super().rename_channels(mapping, allow_duplicates) self._sinfo.set_channel_names(self.info["ch_names"]) @@ -193,7 +193,7 @@ def set_channel_types( mapping: dict[str, str], *, on_unit_change: str = "warn", - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> PlayerLSL: super().set_channel_types( mapping, on_unit_change=on_unit_change, verbose=verbose @@ -202,7 +202,7 @@ def set_channel_types( return self @copy_doc(BasePlayer.set_channel_units) - def set_channel_units(self, mapping: dict[str, Union[str, int]]) -> PlayerLSL: + def set_channel_units(self, mapping: dict[str, str | int]) -> PlayerLSL: super().set_channel_units(mapping) ch_units_after = np.array( [ch["unit_mul"] for ch in self.info["chs"]], dtype=np.int8 diff --git a/mne_lsl/player/player_lsl.pyi b/mne_lsl/player/player_lsl.pyi index 4129ceeff..b02c8f506 100644 --- a/mne_lsl/player/player_lsl.pyi +++ b/mne_lsl/player/player_lsl.pyi @@ -1,5 +1,5 @@ +from collections.abc import Callable from pathlib import Path as Path -from typing import Callable from _typeshed import Incomplete from mne import Annotations diff --git a/mne_lsl/player/tests/test_player_lsl.py b/mne_lsl/player/tests/test_player_lsl.py index 33bceae32..c3fe51400 100644 --- a/mne_lsl/player/tests/test_player_lsl.py +++ b/mne_lsl/player/tests/test_player_lsl.py @@ -432,7 +432,9 @@ def test_player_annotations(raw_annotations, close_io, chunk_size, request): assert stream.info["ch_names"] == annotations assert stream.get_channel_types() == ["misc"] * sinfo.n_channels time.sleep(3) # acquire some annotations - for single, duration in zip(("bad_test", "test2", "test3"), (0.4, 0.1, 0.05)): + for single, duration in zip( + ("bad_test", "test2", "test3"), (0.4, 0.1, 0.05), strict=True + ): data, ts = stream.get_data(picks=single) data = data.squeeze() assert ts.size == data.size diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index a59b5cd21..bcdd14daa 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -10,7 +10,7 @@ from ..utils.logs import logger, warn if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any class StreamFilter(dict): @@ -91,8 +91,8 @@ def __ne__(self, other: Any): # explicit method required to issue warning def create_filter( sfreq: float, - l_freq: Optional[float], - h_freq: Optional[float], + l_freq: float | None, + h_freq: float | None, iir_params: dict[str, Any], ) -> dict[str, Any]: """Create an IIR causal filter. @@ -133,7 +133,7 @@ def create_filter( def ensure_sos_iir_params( - iir_params: Optional[dict[str, Any]] = None, + iir_params: dict[str, Any] | None = None, ) -> dict[str, Any]: """Ensure that the filter parameters include SOS output.""" if iir_params is None: diff --git a/mne_lsl/stream/base.py b/mne_lsl/stream/base.py index 8e795dbf1..45628014b 100644 --- a/mne_lsl/stream/base.py +++ b/mne_lsl/stream/base.py @@ -33,8 +33,9 @@ from ._filters import StreamFilter, create_filter, ensure_sos_iir_params if TYPE_CHECKING: + from collections.abc import Callable from datetime import datetime - from typing import Any, Callable, Optional, Union + from typing import Any from mne import Info from mne.channels import DigMontage @@ -110,10 +111,8 @@ def acquire(self) -> None: @fill_doc def add_reference_channels( self, - ref_channels: Union[str, list[str], tuple[str]], - ref_units: Optional[ - Union[str, int, list[Union[str, int]], tuple[Union[str, int]]] - ] = None, + ref_channels: str | list[str] | tuple[str, ...], + ref_units: str | int | list[str | int] | tuple[str | int, ...] | None = None, ) -> BaseStream: """Add EEG reference channels to data that consists of all zeros. @@ -156,7 +155,7 @@ def add_reference_channels( # error checking and conversion of the arguments to valid values if isinstance(ref_channels, str): ref_channels = [ref_channels] - if isinstance(ref_units, (str, int)): + if isinstance(ref_units, (str | int)): ref_units = [ref_units] elif ref_units is None: ref_units = [0] * len(ref_channels) @@ -238,10 +237,10 @@ def add_reference_channels( @fill_doc def anonymize( self, - daysback: Optional[int] = None, + daysback: int | None = None, keep_his: bool = False, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BaseStream: """Anonymize the measurement information in-place. @@ -340,7 +339,7 @@ def disconnect(self) -> BaseStream: # This method needs to close any inlet/network object and need to end with # self._reset_variables(). - def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> None: + def del_filter(self, idx: int | list[int] | tuple[int, ...] | str = "all") -> None: """Remove a filter from the list of applied filters. Parameters @@ -373,7 +372,7 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non ) elif idx == "all": idx = np.arange(len(self._filters), dtype=np.uint8) - elif isinstance(idx, (tuple, list)): + elif isinstance(idx, (tuple | list)): for elt in idx: check_type(elt, ("int-like",), "idx") idx = np.array(idx, dtype=np.uint8) @@ -419,7 +418,7 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non for k in idx[::-1]: del self._filters[k] - def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStream: + def drop_channels(self, ch_names: str | list[str] | tuple[str, ...]) -> BaseStream: """Drop channel(s). Parameters @@ -456,12 +455,12 @@ def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStre @fill_doc def filter( self, - l_freq: Optional[float], - h_freq: Optional[float], - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, - iir_params: Optional[dict[str, Any]] = None, + l_freq: float | None, + h_freq: float | None, + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, + iir_params: dict[str, Any] | None = None, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BaseStream: # noqa: A003 """Filter the stream with an IIR causal filter. @@ -517,7 +516,7 @@ def filter( @copy_doc(ContainsMixin.get_channel_types) def get_channel_types( self, - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, unique=False, only_data_chs=False, ) -> list[str]: @@ -529,7 +528,7 @@ def get_channel_types( @fill_doc def get_channel_units( self, - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, only_data_chs: bool = False, ) -> list[tuple[int, int]]: """Get a list of channel unit for each channel. @@ -562,9 +561,9 @@ def get_channel_units( @fill_doc def get_data( self, - winsize: Optional[float] = None, - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, - exclude: Union[str, list[str], tuple[str]] = "bads", + winsize: float | None = None, + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, + exclude: str | list[str] | tuple[str, ...] = "bads", ) -> tuple[ScalarArray, NDArray[np.float64]]: """Retrieve the latest data from the buffer. @@ -630,7 +629,7 @@ def get_data( raise # pragma: no cover @copy_doc(SetChannelsMixin.get_montage) - def get_montage(self) -> Optional[DigMontage]: + def get_montage(self) -> DigMontage | None: self._check_connected("get_montage()") return super().get_montage() @@ -639,12 +638,12 @@ def get_montage(self) -> Optional[DigMontage]: def notch_filter( self, freqs: float, - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, - notch_widths: Optional[float] = None, + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, + notch_widths: float | None = None, trans_bandwidth=1, - iir_params: Optional[dict[str, Any]] = None, + iir_params: dict[str, Any] | None = None, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BaseStream: """Filter the stream with an IIR causal notch filter. @@ -734,8 +733,8 @@ def plot(self): # pragma: no cover @fill_doc def pick( self, - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, - exclude: Union[str, list[str], int, list[int], ScalarIntArray] = (), + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, + exclude: str | list[str] | int | list[int] | ScalarIntArray = (), ) -> BaseStream: """Pick a subset of channels. @@ -777,10 +776,10 @@ def record(self): # pragma: no cover @fill_doc def rename_channels( self, - mapping: Union[dict[str, str], Callable], + mapping: dict[str, str] | Callable, allow_duplicates: bool = False, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BaseStream: """Rename channels. @@ -829,7 +828,7 @@ def set_channel_types( mapping: dict[str, str], *, on_unit_change: str = "warn", - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BaseStream: """Define the sensor type of channels. @@ -864,7 +863,7 @@ def set_channel_types( ) return self - def set_channel_units(self, mapping: dict[str, Union[str, int]]) -> BaseStream: + def set_channel_units(self, mapping: dict[str, str | int]) -> BaseStream: """Define the channel unit multiplication factor. The unit itself is defined by the sensor type. Use @@ -896,8 +895,8 @@ def set_channel_units(self, mapping: dict[str, Union[str, int]]) -> BaseStream: @fill_doc def set_eeg_reference( self, - ref_channels: Union[str, list[str], tuple[str]], - ch_type: Union[str, list[str], tuple[str]] = "eeg", + ref_channels: str | list[str] | tuple[str, ...], + ch_type: str | list[str] | tuple[str, ...] = "eeg", ) -> BaseStream: """Specify which reference to use for EEG-like data. @@ -969,7 +968,7 @@ def set_eeg_reference( return self def set_meas_date( - self, meas_date: Optional[Union[datetime, float, tuple[float]]] + self, meas_date: datetime | float | tuple[float, float] | None ) -> BaseStream: """Set the measurement start date. @@ -1001,12 +1000,12 @@ def set_meas_date( @fill_doc def set_montage( self, - montage: Optional[Union[str, DigMontage]], + montage: str | DigMontage | None, match_case: bool = True, - match_alias: Union[bool, dict[str, str]] = False, + match_alias: bool | dict[str, str] = False, on_missing: str = "raise", *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> BaseStream: """Set %(montage_types)s channel positions and digitization points. @@ -1158,7 +1157,7 @@ def _submit_acquisition_job(self) -> None: # ---------------------------------------------------------------------------------- @property - def compensation_grade(self) -> Optional[int]: + def compensation_grade(self) -> int | None: """The current gradient compensation grade. :type: :class:`int` | None @@ -1198,7 +1197,7 @@ def connected(self) -> bool: return True @property - def dtype(self) -> Optional[DTypeLike]: + def dtype(self) -> DTypeLike | None: """Channel format of the stream.""" return getattr(self._buffer, "dtype", None) diff --git a/mne_lsl/stream/base.pyi b/mne_lsl/stream/base.pyi index b5c5434d0..81fb15dda 100644 --- a/mne_lsl/stream/base.pyi +++ b/mne_lsl/stream/base.pyi @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -from collections.abc import Generator +from collections.abc import Callable, Generator from datetime import datetime as datetime -from typing import Any, Callable +from typing import Any import numpy as np from _typeshed import Incomplete @@ -93,7 +93,7 @@ class BaseStream(ABC, ContainsMixin, SetChannelsMixin): def add_reference_channels( self, ref_channels: str | list[str] | tuple[str], - ref_units: str | int | list[str | int] | tuple[str | int] | None = None, + ref_units: str | int | list[str | int] | tuple[str | int, ...] | None = None, ) -> BaseStream: """Add EEG reference channels to data that consists of all zeros. @@ -218,7 +218,7 @@ class BaseStream(ABC, ContainsMixin, SetChannelsMixin): The stream instance modified in-place. """ - def del_filter(self, idx: int | list[int] | tuple[int] | str = "all") -> None: + def del_filter(self, idx: int | list[int] | tuple[int, ...] | str = "all") -> None: """Remove a filter from the list of applied filters. Parameters @@ -235,7 +235,7 @@ class BaseStream(ABC, ContainsMixin, SetChannelsMixin): a step response steady-state. """ - def drop_channels(self, ch_names: str | list[str] | tuple[str]) -> BaseStream: + def drop_channels(self, ch_names: str | list[str] | tuple[str, ...]) -> BaseStream: """Drop channel(s). Parameters diff --git a/mne_lsl/stream/epochs.py b/mne_lsl/stream/epochs.py index 2a009fa57..d12c7a579 100644 --- a/mne_lsl/stream/epochs.py +++ b/mne_lsl/stream/epochs.py @@ -28,8 +28,6 @@ from .base import BaseStream if TYPE_CHECKING: - from typing import Optional, Union - from mne import Info from numpy.typing import NDArray @@ -130,18 +128,18 @@ def __init__( self, stream: BaseStream, bufsize: int, - event_id: Optional[Union[int, dict[str, int]]], - event_channels: Union[str, list[str]], - event_stream: Optional[BaseStream] = None, + event_id: int | dict[str, int] | None, + event_channels: str | list[str], + event_stream: BaseStream | None = None, tmin: float = -0.2, tmax: float = 0.5, - baseline: Optional[tuple[Optional[float], Optional[float]]] = (None, 0), - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, - reject: Optional[dict[str, float]] = None, - flat: Optional[dict[str, float]] = None, - reject_tmin: Optional[float] = None, - reject_tmax: Optional[float] = None, - detrend: Optional[Union[int, str]] = None, + baseline: tuple[float | None, float | None] | None = (None, 0), + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, + reject: dict[str, float] | None = None, + flat: dict[str, float] | None = None, + reject_tmin: float | None = None, + reject_tmax: float | None = None, + detrend: int | str | None = None, ) -> None: check_type(stream, (BaseStream,), "stream") if not stream.connected or stream._info["sfreq"] == 0: @@ -375,9 +373,9 @@ def disconnect(self) -> EpochsStream: @fill_doc def get_data( self, - n_epochs: Optional[int] = None, - picks: Optional[Union[str, list[str], int, list[int], ScalarIntArray]] = None, - exclude: Union[str, list[str], tuple[str]] = "bads", + n_epochs: int | None = None, + picks: str | list[str] | int | list[int] | ScalarIntArray | None = None, + exclude: str | list[str] | tuple[str, ...] = "bads", ) -> ScalarArray: """Retrieve the latest epochs from the buffer. @@ -658,7 +656,7 @@ def times(self) -> NDArray[np.float64]: def _check_event_channels( event_channels: list[str], stream: BaseStream, - event_stream: Optional[BaseStream], + event_stream: BaseStream | None, ) -> None: """Check that the event channels are valid.""" for elt in event_channels: @@ -698,8 +696,8 @@ def _check_event_channels( def _ensure_event_id( - event_id: Optional[Union[int, dict[str, int]]], event_stream: Optional[BaseStream] -) -> Optional[dict[str, int]]: + event_id: int | dict[str, int] | None, event_stream: BaseStream | None +) -> dict[str, int] | None: """Ensure event_ids is a dictionary or None.""" check_type(event_id, (None, int, dict), "event_id") if event_id is None: @@ -741,7 +739,7 @@ def _ensure_event_id( def _check_baseline( - baseline: Optional[tuple[Optional[float], Optional[float]]], + baseline: tuple[float | None, float | None] | None, tmin: float, tmax: float, ) -> None: @@ -766,7 +764,7 @@ def _check_baseline( def _check_reject_flat( - reject: Optional[dict[str, float]], flat: Optional[dict[str, float]], info: Info + reject: dict[str, float] | None, flat: dict[str, float] | None, info: Info ) -> None: """Check that the PTP rejection dictionaries are valid.""" check_type(reject, (dict, None), "reject") @@ -805,7 +803,7 @@ def _check_reject_flat( def _check_reject_tmin_tmax( - reject_tmin: Optional[float], reject_tmax: Optional[float], tmin: float, tmax: float + reject_tmin: float | None, reject_tmax: float | None, tmin: float, tmax: float ) -> None: """Check that the rejection time window is valid.""" check_type(reject_tmin, ("numeric", None), "reject_tmin") @@ -831,7 +829,7 @@ def _check_reject_tmin_tmax( ) -def _ensure_detrend_str(detrend: Optional[Union[int, str]]) -> Optional[str]: +def _ensure_detrend_str(detrend: int | str | None) -> str | None: """Ensure detrend is an integer.""" if detrend is None: return None @@ -854,10 +852,10 @@ def _find_events_in_stim_channels( sfreq: float, *, output: str = "onset", - consecutive: Union[bool, str] = "increasing", + consecutive: bool | str = "increasing", min_duration: float = 0, shortest_event: int = 2, - mask: Optional[int] = None, + mask: int | None = None, uint_cast: bool = False, mask_type: str = "and", initial_event: bool = False, @@ -865,7 +863,7 @@ def _find_events_in_stim_channels( """Find events in stim channels.""" min_samples = min_duration * sfreq events_list = [] - for d, ch_name in zip(data, event_channels): + for d, ch_name in zip(data, event_channels, strict=True): events = find_events( d[np.newaxis, :], first_samp=0, @@ -897,11 +895,11 @@ def _find_events_in_stim_channels( def _prune_events( events: NDArray[np.int64], - event_id: Optional[dict[str, int]], + event_id: dict[str, int] | None, buffer_size: int, ts: NDArray[np.float64], - last_ts: Optional[float], - ts_events: Optional[NDArray[np.float64]], + last_ts: float | None, + ts_events: NDArray[np.float64] | None, tmin_shift: float, ) -> NDArray[np.int64]: """Prune events based on criteria and buffer size.""" @@ -932,12 +930,12 @@ def _prune_events( def _process_data( data: ScalarArray, # array of shape (n_epochs, n_samples, n_channels) - baseline: Optional[tuple[Optional[float], Optional[float]]], - reject: Optional[dict[str, float]], - flat: Optional[dict[str, float]], - reject_tmin: Optional[float], - reject_tmax: Optional[float], - detrend_type: Optional[str], + baseline: tuple[float | None, float | None] | None, + reject: dict[str, float] | None, + flat: dict[str, float] | None, + reject_tmin: float | None, + reject_tmax: float | None, + detrend_type: str | None, times: NDArray[np.float64], ch_idx_by_type: dict[str, list[int]], ) -> ScalarArray: diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index 0a4b73a9a..e07020d5b 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from typing import Optional, Union from mne_lsl.lsl.stream_info import _BaseStreamInfo @@ -55,9 +54,9 @@ def __init__( self, bufsize: float, *, - name: Optional[str] = None, - stype: Optional[str] = None, - source_id: Optional[str] = None, + name: str | None = None, + stype: str | None = None, + source_id: str | None = None, ): super().__init__(bufsize) check_type(name, (str, None), "name") @@ -101,8 +100,8 @@ def connect( self, acquisition_delay: float = 0.001, *, - processing_flags: Optional[Union[str, Sequence[str]]] = None, - timeout: Optional[float] = 2, + processing_flags: str | Sequence[str] | None = None, + timeout: float | None = 2, ) -> StreamLSL: """Connect to the LSL stream and initiate data collection in the buffer. @@ -338,7 +337,7 @@ def connected(self) -> bool: return False @property - def name(self) -> Optional[str]: + def name(self) -> str | None: """Name of the LSL stream. :type: :class:`str` | None @@ -346,7 +345,7 @@ def name(self) -> Optional[str]: return self._name @property - def sinfo(self) -> Optional[_BaseStreamInfo]: + def sinfo(self) -> _BaseStreamInfo | None: """StreamInfo of the connected stream. :type: :class:`~mne_lsl.lsl.StreamInfo` | None @@ -354,7 +353,7 @@ def sinfo(self) -> Optional[_BaseStreamInfo]: return self._sinfo @property - def stype(self) -> Optional[str]: + def stype(self) -> str | None: """Type of the LSL stream. :type: :class:`str` | None @@ -362,7 +361,7 @@ def stype(self) -> Optional[str]: return self._stype @property - def source_id(self) -> Optional[str]: + def source_id(self) -> str | None: """ID of the source of the LSL stream. :type: :class:`str` | None diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index f875149ac..ec985a5b1 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -11,7 +11,7 @@ from mne_lsl.stream._filters import StreamFilter, create_filter, ensure_sos_iir_params if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any @pytest.fixture(scope="module") @@ -69,7 +69,7 @@ def filters(iir_params: dict[str, Any], sfreq: float) -> list[StreamFilter]: h_freqs = (40, 15, None) picks = (np.arange(0, 10), np.arange(10, 20), np.arange(20, 30)) filters = list() - for k, (lfq, hfq, picks_) in enumerate(zip(l_freqs, h_freqs, picks)): + for k, (lfq, hfq, picks_) in enumerate(zip(l_freqs, h_freqs, picks, strict=True)): filt = create_filter( sfreq=sfreq, l_freq=lfq, @@ -122,8 +122,8 @@ def test_StreamFilter_repr(filters: list[StreamFilter]): def test_create_filter( iir_params: dict[str, Any], sfreq: float, - l_freq: Optional[float], - h_freq: Optional[float], + l_freq: float | None, + h_freq: float | None, ): """Test create_filter conformity with MNE.""" filter1 = create_filter( diff --git a/mne_lsl/utils/_checks.py b/mne_lsl/utils/_checks.py index 4f71aced9..e43932606 100644 --- a/mne_lsl/utils/_checks.py +++ b/mne_lsl/utils/_checks.py @@ -4,14 +4,14 @@ import operator import os from pathlib import Path -from typing import Any, Optional +from typing import Any import numpy as np from ._docs import fill_doc -def ensure_int(item: Any, item_name: Optional[str] = None) -> int: +def ensure_int(item: Any, item_name: str | None = None) -> int: """Ensure a variable is an integer. Parameters @@ -67,7 +67,7 @@ def __instancecheck__(cls, other: Any) -> bool: } -def check_type(item: Any, types: tuple, item_name: Optional[str] = None) -> None: +def check_type(item: Any, types: tuple, item_name: str | None = None) -> None: """Check that item is an instance of types. Parameters @@ -123,8 +123,8 @@ def check_type(item: Any, types: tuple, item_name: Optional[str] = None) -> None def check_value( item: Any, allowed_values: tuple, - item_name: Optional[str] = None, - extra: Optional[str] = None, + item_name: str | None = None, + extra: str | None = None, ) -> None: """Check the value of a parameter against a list of valid options. diff --git a/mne_lsl/utils/_docs.py b/mne_lsl/utils/_docs.py index c423418ee..e3f666a4a 100644 --- a/mne_lsl/utils/_docs.py +++ b/mne_lsl/utils/_docs.py @@ -5,7 +5,7 @@ """ import sys -from typing import Callable +from collections.abc import Callable from mne.utils.docs import docdict as docdict_mne diff --git a/mne_lsl/utils/_docs.pyi b/mne_lsl/utils/_docs.pyi index 0cf2b0fa2..322c52253 100644 --- a/mne_lsl/utils/_docs.pyi +++ b/mne_lsl/utils/_docs.pyi @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable docdict: dict[str, str] _KEYS_MNE: tuple[str, ...] diff --git a/mne_lsl/utils/_fixes.py b/mne_lsl/utils/_fixes.py index 702e3227a..d84d90c66 100644 --- a/mne_lsl/utils/_fixes.py +++ b/mne_lsl/utils/_fixes.py @@ -9,8 +9,6 @@ from mne.utils import check_version if TYPE_CHECKING: - from typing import Optional, Union - from .._typing import ScalarArray @@ -34,11 +32,11 @@ def __getattr__(self, name): # noqa: D105 def find_events( data: ScalarArray, first_samp: int, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, output: str = None, - consecutive: Union[bool, str] = None, + consecutive: bool | str = None, min_samples: float = None, - mask: Optional[int] = None, + mask: int | None = None, uint_cast: bool = None, mask_type: str = None, initial_event: bool = None, diff --git a/mne_lsl/utils/_imports.py b/mne_lsl/utils/_imports.py index 6821a6f48..7e3f7cf0d 100644 --- a/mne_lsl/utils/_imports.py +++ b/mne_lsl/utils/_imports.py @@ -5,12 +5,12 @@ from __future__ import annotations -import importlib +from importlib import import_module +from importlib.util import find_spec from typing import TYPE_CHECKING if TYPE_CHECKING: from types import ModuleType - from typing import Optional # A mapping from import name to package name (on PyPI) when the package name # is different. @@ -29,7 +29,7 @@ def import_optional_dependency( name: str, extra: str = "", raise_error: bool = True, -) -> Optional[ModuleType]: +) -> ModuleType | None: """Import an optional dependency. By default, if a dependency is missing an ImportError with a nice message will be @@ -48,22 +48,18 @@ def import_optional_dependency( Returns ------- - module : Optional[ModuleType] + module : Module | None The imported module when found. None is returned when the package is not found and raise_error is False. """ package_name = _INSTALL_MAPPING.get(name) install_name = package_name if package_name is not None else name - - try: - module = importlib.import_module(name) - except ImportError: + if find_spec(name) is None: if raise_error: raise ImportError( - f"Missing optional dependency '{install_name}'. {extra} " - f"Use pip or conda to install {install_name}." + f"Missing optional dependency '{install_name}'. {extra} Use pip or " + f"conda to install {install_name}." ) else: return None - - return module + return import_module(name) diff --git a/mne_lsl/utils/_tests.py b/mne_lsl/utils/_tests.py index 827fd98ab..41f2672ae 100644 --- a/mne_lsl/utils/_tests.py +++ b/mne_lsl/utils/_tests.py @@ -10,7 +10,6 @@ if TYPE_CHECKING: from pathlib import Path - from typing import Union from mne import Info from mne.io import BaseRaw @@ -18,7 +17,7 @@ from .._typing import ScalarArray -def sha256sum(fname: Union[str, Path]) -> str: +def sha256sum(fname: str | Path) -> str: """Efficiently hash a file.""" h = hashlib.sha256() b = bytearray(128 * 1024) @@ -112,7 +111,7 @@ def compare_infos(info1: Info, info2: Info) -> None: assert len(info1["projs"]) == len(info2["projs"]) projs1 = sorted(info1["projs"], key=lambda x: x["desc"]) projs2 = sorted(info2["projs"], key=lambda x: x["desc"]) - for proj1, proj2 in zip(projs1, projs2): + for proj1, proj2 in zip(projs1, projs2, strict=True): assert proj1["desc"] == proj2["desc"] assert proj1["kind"] == proj2["kind"] assert proj1["data"]["nrow"] == proj2["data"]["nrow"] @@ -125,7 +124,7 @@ def compare_infos(info1: Info, info2: Info) -> None: assert len(info1["dig"]) == len(info2["dig"]) digs1 = sorted(info1["dig"], key=lambda x: (x["kind"], x["ident"])) digs2 = sorted(info2["dig"], key=lambda x: (x["kind"], x["ident"])) - for dig1, dig2 in zip(digs1, digs2): + for dig1, dig2 in zip(digs1, digs2, strict=True): assert dig1["kind"] == dig2["kind"] assert dig1["ident"] == dig2["ident"] assert dig1["coord_frame"] == dig2["coord_frame"] diff --git a/mne_lsl/utils/config.py b/mne_lsl/utils/config.py index 5986a76c0..4467c42f7 100644 --- a/mne_lsl/utils/config.py +++ b/mne_lsl/utils/config.py @@ -1,8 +1,9 @@ import platform import sys +from collections.abc import Callable from functools import lru_cache, partial from importlib.metadata import requires, version -from typing import IO, Callable, Optional +from typing import IO import psutil from packaging.requirements import Requirement @@ -11,7 +12,7 @@ from .logs import _use_log_level -def sys_info(fid: Optional[IO] = None, developer: bool = False): +def sys_info(fid: IO | None = None, developer: bool = False): """Print the system information for debugging. Parameters @@ -160,7 +161,7 @@ def _list_dependencies_info( @lru_cache(maxsize=1) -def _get_gpu_info() -> tuple[Optional[str], Optional[str]]: +def _get_gpu_info() -> tuple[str | None, str | None]: """Get the GPU information.""" try: from pyvista import GPUInfo diff --git a/mne_lsl/utils/config.pyi b/mne_lsl/utils/config.pyi index d25255b43..53c1088ff 100644 --- a/mne_lsl/utils/config.pyi +++ b/mne_lsl/utils/config.pyi @@ -1,4 +1,5 @@ -from typing import IO, Callable +from collections.abc import Callable +from typing import IO from packaging.requirements import Requirement diff --git a/mne_lsl/utils/logs.py b/mne_lsl/utils/logs.py index 16a98f387..5f342d7d8 100644 --- a/mne_lsl/utils/logs.py +++ b/mne_lsl/utils/logs.py @@ -14,8 +14,8 @@ from ._fixes import WrapStdOut if TYPE_CHECKING: + from collections.abc import Callable from logging import Logger - from typing import Callable, Optional, Union _PACKAGE: str = __package__.split(".")[0] @@ -24,9 +24,7 @@ @fill_doc def _init_logger( *, - verbose: Optional[Union[bool, str, int]] = os.getenv( - "MNE_LSL_LOG_LEVEL", "WARNING" - ), + verbose: bool | str | int | None = os.getenv("MNE_LSL_LOG_LEVEL", "WARNING"), ) -> Logger: """Initialize a logger. @@ -54,11 +52,11 @@ def _init_logger( def add_file_handler( - fname: Union[str, Path], + fname: str | Path, mode: str = "a", - encoding: Optional[str] = None, + encoding: str | None = None, *, - verbose: Optional[Union[bool, str, int]] = None, + verbose: bool | str | int | None = None, ) -> None: """Add a file handler to the logger. @@ -85,7 +83,7 @@ def add_file_handler( @fill_doc -def set_log_level(verbose: Optional[Union[bool, str, int]]) -> None: +def set_log_level(verbose: bool | str | int | None) -> None: """Set the log level for the logger. Parameters @@ -172,12 +170,12 @@ class _use_log_level: def __init__( self, - verbose: Optional[Union[bool, str, int]] = None, - logger_obj: Optional[Logger] = None, + verbose: bool | str | int | None = None, + logger_obj: Logger | None = None, ): self._logger: Logger = logger_obj if logger_obj is not None else logger self._old_level: int = self._logger.level - self._level: Optional[int] = None if verbose is None else check_verbose(verbose) + self._level: int | None = None if verbose is None else check_verbose(verbose) def __enter__(self): if self._level is not None: @@ -193,7 +191,7 @@ def warn( message: str, category: Warning = RuntimeWarning, module: str = _PACKAGE, - ignore_namespaces: Union[tuple[str, ...] | list[str]] = (_PACKAGE,), + ignore_namespaces: tuple[str, ...] | list[str] = (_PACKAGE,), ) -> None: """Emit a warning with trace outside the requested namespace. diff --git a/mne_lsl/utils/logs.pyi b/mne_lsl/utils/logs.pyi index 9cdc2771a..e97b28d28 100644 --- a/mne_lsl/utils/logs.pyi +++ b/mne_lsl/utils/logs.pyi @@ -1,7 +1,7 @@ import logging +from collections.abc import Callable from logging import Logger from pathlib import Path -from typing import Callable from _typeshed import Incomplete diff --git a/mne_lsl/utils/meas_info.py b/mne_lsl/utils/meas_info.py index 110160982..2ca8d6a4d 100644 --- a/mne_lsl/utils/meas_info.py +++ b/mne_lsl/utils/meas_info.py @@ -17,7 +17,7 @@ from .logs import logger, warn if TYPE_CHECKING: - from typing import Any, Optional, Union + from typing import Any from mne import Info @@ -49,7 +49,7 @@ def create_info( n_channels: int, sfreq: float, stype: str, - desc: Optional[Union[_BaseStreamInfo, dict[str, Any]]], + desc: _BaseStreamInfo | dict[str, Any] | None, ) -> Info: """Create a minimal :class:`mne.Info` object from an LSL stream attributes. @@ -118,7 +118,7 @@ def create_info( info["sfreq"] = sfreq info["lowpass"] = 0.0 info["highpass"] = 0.0 - for ch, ch_unit in zip(info["chs"], ch_units): + for ch, ch_unit in zip(info["chs"], ch_units, strict=True): ch["unit_mul"] = ch_unit # add manufacturer information if available info["device_info"] = dict() @@ -172,7 +172,7 @@ def _create_default_info( # --------------------- Functions to read from a description sinfo --------------------- def _read_desc_sinfo( n_channels: int, stype: str, desc: _BaseStreamInfo -) -> tuple[list[str], list[str], list[int], Optional[str]]: +) -> tuple[list[str], list[str], list[int], str | None]: """Read channel information from a StreamInfo. If the StreamInfo is retrieved by resolve_streams, the description will be empty. @@ -203,7 +203,7 @@ def _read_desc_sinfo( try: ch_units = list() - for ch_type, ch_unit in zip(ch_types, desc.get_channel_units()): + for ch_type, ch_unit in zip(ch_types, desc.get_channel_units(), strict=True): ch_unit = ch_unit.lower().strip() fiff_unit = _CH_TYPES_DICT[ch_type]["unit"] if fiff_unit in _HUMAN_UNITS: @@ -231,7 +231,7 @@ def _read_desc_sinfo( # --------------------- Functions to read from a description dict ---------------------- def _read_desc_dict( n_channels: int, stype: str, desc: dict[str, Any] -) -> tuple[list[str], list[str], list[int], Optional[str]]: +) -> tuple[list[str], list[str], list[int], str | None]: """Read channel information from a description dictionary. A dictionary is returned from loading an XDF file. @@ -301,7 +301,7 @@ def _safe_get(channel, item, default) -> str: # ----------------------------- Functions to edit an Info ------------------------------ -def _set_channel_units(info: Info, mapping: dict[str, Union[str, int]]) -> None: +def _set_channel_units(info: Info, mapping: dict[str, str | int]) -> None: """Set the channel unit multiplication factor.""" check_type(mapping, (dict,), "mapping") mapping_idx = dict() # to avoid overwriting the input dictionary diff --git a/mne_lsl/utils/tests/test_logs.py b/mne_lsl/utils/tests/test_logs.py index 9fd4e9f94..d443e1c9d 100644 --- a/mne_lsl/utils/tests/test_logs.py +++ b/mne_lsl/utils/tests/test_logs.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from pathlib import Path - from typing import Optional, Union def test_default_log_level(caplog: pytest.LogCaptureFixture): @@ -68,7 +67,7 @@ def test_verbose(caplog: pytest.LogCaptureFixture): # function @verbose - def foo(verbose: Optional[Union[bool, str, int]] = None): + def foo(verbose: bool | str | int | None = None): """Foo function.""" logger.debug("101") @@ -94,12 +93,12 @@ def __init__(self): pass @verbose - def foo(self, verbose: Optional[Union[bool, str, int]] = None): + def foo(self, verbose: bool | str | int | None = None): logger.debug("101") @staticmethod @verbose - def foo2(verbose: Optional[Union[bool, str, int]] = None): + def foo2(verbose: bool | str | int | None = None): logger.debug("101") foo = Foo() diff --git a/mne_lsl/utils/tests/test_meas_info.py b/mne_lsl/utils/tests/test_meas_info.py index 24131d174..af12fb71c 100644 --- a/mne_lsl/utils/tests/test_meas_info.py +++ b/mne_lsl/utils/tests/test_meas_info.py @@ -15,7 +15,7 @@ def test_valid_info(): # nested desc = dict(channels=list()) desc["channels"].append(dict(channel=list())) - for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units): + for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units, strict=True): desc["channels"][0]["channel"].append( dict(label=[ch_name], unit=[ch_unit], type=[ch_type]) ) @@ -31,7 +31,7 @@ def test_valid_info(): # non-nested desc = dict(channels=list()) desc["channels"].append(dict(channel=list())) - for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units): + for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units, strict=True): desc["channels"][0]["channel"].append( dict(label=ch_name, unit=ch_unit, type=ch_type) ) @@ -70,7 +70,7 @@ def test_valid_info(): # nested desc = dict(channels=list()) desc["channels"].append(dict(channel=list())) - for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units): + for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units, strict=True): desc["channels"][0]["channel"].append( dict(label=[ch_name], unit=[ch_unit], type=[ch_type]) ) @@ -91,7 +91,7 @@ def test_invalid_info(): desc = dict(channels=list()) desc["channels"].append(dict(channel=list())) - for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units): + for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units, strict=True): desc["channels"][0]["channel"].append( dict(label=[ch_name], unit=[ch_unit], type=[ch_type]) ) @@ -111,7 +111,7 @@ def test_invalid_info(): # nested desc = dict(channels=list()) desc["channels"].append(dict(channel=list())) - for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units): + for ch_name, ch_type, ch_unit in zip(ch_names, ch_types, ch_units, strict=True): desc["channels"][0]["channel"].append( dict(label=[ch_name], unit=[ch_unit], type=[ch_type]) ) @@ -147,7 +147,7 @@ def test_manufacturer(): # nested desc = dict(channels=list(), manufacturer=list()) desc["channels"].append(dict(channel=list())) - for ch_name, ch_type in zip(ch_names, ch_types): + for ch_name, ch_type in zip(ch_names, ch_types, strict=True): desc["channels"][0]["channel"].append( dict(label=[ch_name], unit=["uv"], type=[ch_type]) ) @@ -159,7 +159,7 @@ def test_manufacturer(): # not nested desc = dict(channels=list(), manufacturer="101") desc["channels"].append(dict(channel=list())) - for ch_name, ch_type in zip(ch_names, ch_types): + for ch_name, ch_type in zip(ch_names, ch_types, strict=True): desc["channels"][0]["channel"].append( dict(label=[ch_name], unit=["uv"], type=[ch_type]) ) diff --git a/mne_lsl/utils/tests/test_test.py b/mne_lsl/utils/tests/test_test.py index e3efcd0e3..7ba2d278c 100644 --- a/mne_lsl/utils/tests/test_test.py +++ b/mne_lsl/utils/tests/test_test.py @@ -74,6 +74,7 @@ def test_compare_infos(raw): for param, value in zip( ("kind", "coil_type", "loc", "unit", "unit_mul", "ch_name", "coord_frame"), (202, 1, np.ones(12), 107, -6, "101", 0), + strict=True, ): info2 = info.copy() compare_infos(info, info2) diff --git a/pyproject.toml b/pyproject.toml index 193fe45cd..0924fc37e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', - 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.13', 'Topic :: Scientific/Engineering', 'Topic :: Software Development', ] @@ -54,7 +54,7 @@ maintainers = [ ] name = 'mne_lsl' readme = 'README.md' -requires-python = '>=3.9' +requires-python = '>=3.10' version = '1.7.0.dev0' [project.optional-dependencies] @@ -152,7 +152,6 @@ select = ['A', 'B', 'D', 'E', 'F', 'G', 'I', 'LOG', 'NPY', 'PIE', 'PT', 'T20', ' 'D100', # 'Missing docstring in public module' 'D104', # 'Missing docstring in public package' 'D107', # 'Missing docstring in __init__' - 'UP007', # 'Use `X | Y` for type annotations', requires python 3.10 ] '*.pyi' = ['D', 'E501', 'F811'] '*/stream_viewer/backends/*' = ['D401'] diff --git a/tools/stubgen.py b/tools/stubgen.py index 22c619a53..408c4bdb3 100644 --- a/tools/stubgen.py +++ b/tools/stubgen.py @@ -43,7 +43,7 @@ objects = [ node for node in module_ast.body - if isinstance(node, (ast.ClassDef, ast.FunctionDef)) + if isinstance(node, (ast.ClassDef | ast.FunctionDef)) ] for node in objects: docstring = getattr(module, node.name).__doc__ diff --git a/tutorials/20_player_annotations.py b/tutorials/20_player_annotations.py index b3373624f..20c6c3ef2 100644 --- a/tutorials/20_player_annotations.py +++ b/tutorials/20_player_annotations.py @@ -143,7 +143,7 @@ data_annotations, ts_annotations = stream_annotations.get_data( winsize=stream_annotations.n_new_samples ) - for sample, time in zip(data_annotations.T, ts_annotations): + for sample, time in zip(data_annotations.T, ts_annotations, strict=True): k = np.where(sample != 0)[0][0] # find the annotation ax.axvspan( time,