From 16f2b65017b670bfd39dcc93c3a4c50efda70497 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Tue, 27 Feb 2024 15:36:31 +0100 Subject: [PATCH 01/69] add arguments and description of the filtering method --- mne_lsl/stream/_base.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index ec9d62906..ca9c67437 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: from datetime import datetime - from typing import Callable, Optional, Union + from typing import Any, Callable, Optional, Union from mne import Info from mne.channels import DigMontage @@ -333,8 +333,37 @@ def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStre self._pick(picks) return self - def filter(self) -> BaseStream: # noqa: A003 - """Filter the stream. Not implemented. + @fill_doc + def filter( + self, + l_freq: Optional[float], + h_freq: Optional[float], + picks, + iir_params: Optional[dict[str, Any]], + ) -> BaseStream: # noqa: A003 + """Filter the stream with an IIR causal filter. + + Once a filter is applied, the buffer is updated in real-time with the filtered + data. It is not possible to remove an applied filter. It is possible to apply + more than one filter. + + .. code-block:: python + + stream = Stream(2.0).connect() + stream.filter(1.0, 40.0, picks="eeg") + stream.filter(1.0, 15.0, picks="ecg").filter(0.1, 5, picks="EDA") + + Parameters + ---------- + l_freq : float | None + The lower cutoff frequency. If None, the buffer is only low-passed. + h_freq : float | None + The higher cutoff frequency. If None, the buffer is only high-passed. + %(picks_all)s + iir_params : dict | None + Dictionary of parameters to use for IIR filtering. If None, a 4th order + Butterworth will be used. For more information, see + :func:`mne.filter.construct_iir_filter`. Returns ------- From 667b343df9b6a85c2048944109cf8628d3b93523 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Tue, 27 Feb 2024 15:51:52 +0100 Subject: [PATCH 02/69] fix picks/docstrings --- mne_lsl/stream/_base.py | 8 ++++---- mne_lsl/utils/_docs.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index ca9c67437..000843cb2 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -371,7 +371,7 @@ def filter( The stream instance modified in-place. """ self._check_connected_and_regular_sampling("filter()") - raise NotImplementedError + picks = _picks_to_idx(self._info, picks, "all", "bads", allow_empty=False) @copy_doc(ContainsMixin.get_channel_types) def get_channel_types( @@ -405,7 +405,7 @@ def get_channel_units( self._check_connected(name="get_channel_units()") check_type(only_data_chs, (bool,), "only_data_chs") none = "data" if only_data_chs else "all" - picks = _picks_to_idx(self._info, picks, none, (), allow_empty=False) + picks = _picks_to_idx(self._info, picks, none, "bads", allow_empty=False) channel_units = list() for idx in picks: channel_units.append( @@ -466,7 +466,7 @@ def get_data( # 8.68 µs ± 113 ns per loop # >>> %timeit _picks_to_idx(raw.info, None) # 253 µs ± 1.22 µs per loop - picks = _picks_to_idx(self._info, picks, none="all") + picks = _picks_to_idx(self._info, picks, none="all", exclude="bads") self._n_new_samples = 0 # reset the number of new samples return self._buffer[-n_samples:, picks].T, self._timestamps[-n_samples:] except Exception: @@ -499,7 +499,7 @@ def pick(self, picks, exclude=()) -> BaseStream: Parameters ---------- - %(picks_all)s + %(picks_base)s all channels. exclude : str | list of str Set of channels to exclude, only used when picking is based on types, e.g. ``exclude='bads'`` when ``picks="meg"``. diff --git a/mne_lsl/utils/_docs.py b/mne_lsl/utils/_docs.py index 35a55bed6..89e39814e 100644 --- a/mne_lsl/utils/_docs.py +++ b/mne_lsl/utils/_docs.py @@ -23,6 +23,7 @@ "montage_types", "on_missing_montage", "picks_all", + "picks_base", "ref_channels", ) From ead65c3d103c0dfd9a69656003d1a20332347185 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Tue, 27 Feb 2024 16:26:43 +0100 Subject: [PATCH 03/69] add filter validation --- mne_lsl/stream/_base.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 000843cb2..dfb752440 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -5,11 +5,13 @@ from math import ceil from threading import Timer from typing import TYPE_CHECKING +from warnings import warn import numpy as np from mne import pick_info, pick_types from mne.channels import rename_channels -from mne.utils import check_version +from mne.filter import construct_iir_filter +from mne.utils import check_version, use_log_level if check_version("mne", "1.6"): from mne._fiff.constants import FIFF, _ch_unit_mul_named @@ -339,7 +341,9 @@ def filter( l_freq: Optional[float], h_freq: Optional[float], picks, - iir_params: Optional[dict[str, Any]], + iir_params: Optional[dict[str, Any]] = dict( + order=4, ftype="butter", output="sos" + ), ) -> BaseStream: # noqa: A003 """Filter the stream with an IIR causal filter. @@ -365,13 +369,42 @@ def filter( Butterworth will be used. For more information, see :func:`mne.filter.construct_iir_filter`. + .. note:: + + The output ``sos`` must be used. The ``ba`` output is not supported. + Returns ------- stream : instance of ``Stream`` The stream instance modified in-place. """ self._check_connected_and_regular_sampling("filter()") + # validate the arguments picks = _picks_to_idx(self._info, picks, "all", "bads", allow_empty=False) + if ("output" in iir_params and iir_params["output"] != "sos") or all( + key in iir_params for key in ("a", "b") + ): + warn( + "Only 'sos' output is supported for real-time filtering. The filter " + "output will be automatically changed. Please set " + "iir_params=dict(output='sos', ...) in your call to filter().", + RuntimeWarning, + stacklevel=2, + ) + for key in ("a", "b"): + if key in iir_params: + del iir_params[key] + iir_params["output"] = "sos" + # construt an IIR filter + with use_log_level(logger.level): # ensure MNE log is set to the same level + iir_params = construct_iir_filter( + iir_params=iir_params, + f_pass=None, + f_stop=None, + sfreq=self._info["sfreq"], + return_copy=False, + phase="forward", + ) @copy_doc(ContainsMixin.get_channel_types) def get_channel_types( From cd736ec4963b20c2d5507c5b914500fc4e1f2c93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 15:28:29 +0000 Subject: [PATCH 04/69] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne_lsl/stream/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index dfb752440..db9f22ef3 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -395,7 +395,7 @@ def filter( if key in iir_params: del iir_params[key] iir_params["output"] = "sos" - # construt an IIR filter + # construct an IIR filter with use_log_level(logger.level): # ensure MNE log is set to the same level iir_params = construct_iir_filter( iir_params=iir_params, From d29590af93121b25e611d6071e092d72b3b39d9c Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Tue, 27 Feb 2024 16:57:11 +0100 Subject: [PATCH 05/69] fix logging in filter function --- mne_lsl/stream/_base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index db9f22ef3..449c66e1e 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -29,7 +29,7 @@ from ..utils._checks import check_type, check_value from ..utils._docs import copy_doc, fill_doc -from ..utils.logs import logger +from ..utils.logs import logger, verbose from ..utils.meas_info import _HUMAN_UNITS, _set_channel_units if TYPE_CHECKING: @@ -335,6 +335,7 @@ def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStre self._pick(picks) return self + @verbose @fill_doc def filter( self, @@ -344,6 +345,8 @@ def filter( iir_params: Optional[dict[str, Any]] = dict( order=4, ftype="butter", output="sos" ), + *, + verbose: Optional[Union[bool, str, int]] = None, ) -> BaseStream: # noqa: A003 """Filter the stream with an IIR causal filter. @@ -372,6 +375,7 @@ def filter( .. note:: The output ``sos`` must be used. The ``ba`` output is not supported. + %(verbose)s Returns ------- @@ -396,7 +400,7 @@ def filter( del iir_params[key] iir_params["output"] = "sos" # construct an IIR filter - with use_log_level(logger.level): # ensure MNE log is set to the same level + with use_log_level(logger.level if verbose is None else verbose): iir_params = construct_iir_filter( iir_params=iir_params, f_pass=None, From a2c409f8361665dabf47f3eab7a3682d14ae282f Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Tue, 27 Feb 2024 17:05:19 +0100 Subject: [PATCH 06/69] fix B006 --- mne_lsl/stream/_base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 84a5709db..266c84d64 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -347,9 +347,7 @@ def filter( l_freq: Optional[float], h_freq: Optional[float], picks, - iir_params: Optional[dict[str, Any]] = dict( - order=4, ftype="butter", output="sos" - ), + iir_params: Optional[dict[str, Any]] = None, *, verbose: Optional[Union[bool, str, int]] = None, ) -> BaseStream: # noqa: A003 @@ -390,6 +388,11 @@ def filter( self._check_connected_and_regular_sampling("filter()") # validate the arguments picks = _picks_to_idx(self._info, picks, "all", "bads", allow_empty=False) + iir_params = ( + dict(order=4, ftype="butter", output="sos") + if iir_params is None + else iir_params + ) if ("output" in iir_params and iir_params["output"] != "sos") or all( key in iir_params for key in ("a", "b") ): From 3727e834aa4f897808ce422b76d59e3dc366ff77 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 28 Feb 2024 11:30:38 +0100 Subject: [PATCH 07/69] use create_filter directly and add filter storage --- mne_lsl/stream/_base.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 266c84d64..b08d8810b 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -10,8 +10,9 @@ import numpy as np from mne import pick_info, pick_types from mne.channels import rename_channels -from mne.filter import construct_iir_filter +from mne.filter import create_filter from mne.utils import check_version, use_log_level +from scipy.signal import sosfilt_zi if check_version("mne", "1.6"): from mne._fiff.constants import FIFF, _ch_unit_mul_named @@ -408,15 +409,22 @@ def filter( del iir_params[key] iir_params["output"] = "sos" # construct an IIR filter - with use_log_level(logger.level if verbose is None else verbose): - iir_params = construct_iir_filter( - iir_params=iir_params, - f_pass=None, - f_stop=None, - sfreq=self._info["sfreq"], - return_copy=False, - phase="forward", - ) + filter_ = create_filter( + data=None, + sfreq=self._info["sfreq"], + l_freq=l_freq, + h_freq=h_freq, + method="iir", + iir_params=iir_params, + phase="forward", + verbose=logger.level if verbose is None else verbose, + ) + filter_["zi"] = None # add initial conditions + filter_["zi_coeff"] = sosfilt_zi(self._sos) + filter_["picks"] = picks + # add filter to the list of applied filters + with self._interrupt_acquisition(): + self._filters.append(filter_) @copy_doc(ContainsMixin.get_channel_types) def get_channel_types( @@ -938,6 +946,7 @@ def _reset_variables(self) -> None: self._added_channels = [] self._ref_channels = None self._ref_from = None + self._filters = [] self._timestamps = None # This method needs to reset any stream-system-specific variables, e.g. an inlet # or a StreamInfo for LSL streams. From 0aaa2909cd47d399e3299468328a4cbf782b938e Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 28 Feb 2024 11:32:28 +0100 Subject: [PATCH 08/69] drop use_log_level --- mne_lsl/stream/_base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index b08d8810b..c8807a794 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -11,7 +11,7 @@ from mne import pick_info, pick_types from mne.channels import rename_channels from mne.filter import create_filter -from mne.utils import check_version, use_log_level +from mne.utils import check_version from scipy.signal import sosfilt_zi if check_version("mne", "1.6"): @@ -923,8 +923,7 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: ) with self._interrupt_acquisition(): - with use_log_level(logger.level): - self._info = pick_info(self._info, picks) + self._info = pick_info(self._info, picks, verbose=logger.level) self._picks_inlet = self._picks_inlet[picks_inlet] self._buffer = self._buffer[:, picks] From 0682c833bd0c8e8204bf845feb336a6b98ec717e Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 28 Feb 2024 11:37:39 +0100 Subject: [PATCH 09/69] edit filters during pick operation --- mne_lsl/stream/_base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index c8807a794..5e81a7da8 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -926,11 +926,22 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: self._info = pick_info(self._info, picks, verbose=logger.level) self._picks_inlet = self._picks_inlet[picks_inlet] self._buffer = self._buffer[:, picks] - # prune added channels which are not part of the inlet for ch in self._added_channels[::-1]: if ch not in self.ch_names: self._added_channels.remove(ch) + # remove dropped channels from filters + filters2remove = [] + for k, filter_ in enumerate(self._filters): + filter_["picks"] = np.intersect1d( + filter_["picks"], picks, assume_unique=True + ) + if filter_["picks"].size == 0: + filters2remove.append(k) + continue + filter_["zi"] = None # reset initial conditions + for k in filters2remove[::-1]: + del self._filters[k] @abstractmethod def _reset_variables(self) -> None: From b2a9ecd25e02e793b96f448916900ab63ce9e11b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 28 Feb 2024 13:55:10 +0100 Subject: [PATCH 10/69] add filter placeholder --- mne_lsl/stream/stream_lsl.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index 2dc38a3de..90534da87 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -229,6 +229,11 @@ def _acquire(self) -> None: self._create_acquisition_thread(self._acquisition_delay) return # interrupt early + # select the last self._timestamps.size samples from data and timestamps in + # case more samples than the buffer can hold were retrieved. + data = data[-self._timestamps.size :, :] + timestamps = timestamps[-self._timestamps.size :] + # process acquisition window n_channels = self._inlet.n_channels assert data.ndim == 2 and data.shape[-1] == n_channels, ( @@ -250,6 +255,10 @@ def _acquire(self) -> None: data_ref = data[:, self._ref_channels].mean(axis=1, keepdims=True) data[:, self._ref_from] -= data_ref + # apply filters + for filter_ in self._filters: # noqa + pass + # roll and update buffers self._buffer = np.roll(self._buffer, -timestamps.size, axis=0) self._timestamps = np.roll(self._timestamps, -timestamps.size, axis=0) @@ -260,16 +269,11 @@ def _acquire(self) -> None: n_channels, self._picks_inlet.size, ) - # select the last self._timestamps.size samples from data and timestamps in - # case more samples than the buffer can hold were retrieved. - self._buffer[-timestamps.size :, :] = data[-self._timestamps.size :, :] - self._timestamps[-timestamps.size :] = timestamps[-self._timestamps.size :] + self._buffer[-timestamps.size :, :] = data + self._timestamps[-timestamps.size :] = timestamps # update the number of new samples available - self._n_new_samples += min(timestamps.size, self._timestamps.size) - if ( - self._timestamps.size < self._n_new_samples - or self._timestamps.size < timestamps.size - ): + self._n_new_samples += timestamps.size + if self._timestamps.size < self._n_new_samples: logger.info( "The number of new samples exceeds the buffer size. Consider using " "a larger buffer by creating a Stream with a larger 'bufsize' " From 213f526a15eb5a7f263ae8bd1c6ab1ddae789915 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 28 Feb 2024 16:53:02 +0100 Subject: [PATCH 11/69] add filtering of every buffer --- mne_lsl/stream/_base.py | 2 +- mne_lsl/stream/stream_lsl.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 5e81a7da8..b20253e71 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -420,7 +420,7 @@ def filter( verbose=logger.level if verbose is None else verbose, ) filter_["zi"] = None # add initial conditions - filter_["zi_coeff"] = sosfilt_zi(self._sos) + filter_["zi_coeff"] = sosfilt_zi(filter_["sos"])[..., np.newaxis] filter_["picks"] = picks # add filter to the list of applied filters with self._interrupt_acquisition(): diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index 90534da87..bafcb40be 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -6,6 +6,7 @@ import numpy as np from mne.utils import check_version +from scipy.signal import sosfilt if check_version("mne", "1.5"): from mne.io.constants import FIFF @@ -255,9 +256,17 @@ def _acquire(self) -> None: data_ref = data[:, self._ref_channels].mean(axis=1, keepdims=True) data[:, self._ref_from] -= data_ref - # apply filters - for filter_ in self._filters: # noqa - pass + # apply filters on (n_times, n_channels) data + for filter_ in self._filters: + if filter_["zi"] is None: + # initial conditions are set to a step response steady-state set + # on the mean on the acquisition window (e.g. DC offset for EEGs) + filter_["zi"] = filter_["zi_coeff"] * np.mean( + data[filter_["picks"]], axis=0 + ) + data, filter_["zi"] = sosfilt( + filter_["sos"], data[filter_["picks"]], zi=filter_["zi"], axis=0 + ) # roll and update buffers self._buffer = np.roll(self._buffer, -timestamps.size, axis=0) From aa765d68dd02609e1e60751a790ceb0adbd24150 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 11:19:30 +0100 Subject: [PATCH 12/69] fix data indexing --- mne_lsl/stream/stream_lsl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index bafcb40be..7d6eb2969 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -262,11 +262,12 @@ def _acquire(self) -> None: # initial conditions are set to a step response steady-state set # on the mean on the acquisition window (e.g. DC offset for EEGs) filter_["zi"] = filter_["zi_coeff"] * np.mean( - data[filter_["picks"]], axis=0 + data[:, filter_["picks"]], axis=0 ) - data, filter_["zi"] = sosfilt( - filter_["sos"], data[filter_["picks"]], zi=filter_["zi"], axis=0 + data_, filter_["zi"] = sosfilt( + filter_["sos"], data[:, filter_["picks"]], zi=filter_["zi"], axis=0 ) + data[:, filter_["picks"]] = data_ # roll and update buffers self._buffer = np.roll(self._buffer, -timestamps.size, axis=0) From 49e7d8ea17ba37fd6162c2550b6975db2ba318a6 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 11:19:50 +0100 Subject: [PATCH 13/69] fix variable name --- mne_lsl/stream/stream_lsl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index 7d6eb2969..e51530acc 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -264,10 +264,10 @@ def _acquire(self) -> None: filter_["zi"] = filter_["zi_coeff"] * np.mean( data[:, filter_["picks"]], axis=0 ) - data_, filter_["zi"] = sosfilt( + data_filtered, filter_["zi"] = sosfilt( filter_["sos"], data[:, filter_["picks"]], zi=filter_["zi"], axis=0 ) - data[:, filter_["picks"]] = data_ + data[:, filter_["picks"]] = data_filtered # roll and update buffers self._buffer = np.roll(self._buffer, -timestamps.size, axis=0) From fc52dd642ae3aba5563e501b73f4a2230c6401ba Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 11:27:03 +0100 Subject: [PATCH 14/69] improve assertion and array selection --- mne_lsl/stream/stream_lsl.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index e51530acc..79b0d38f0 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -230,18 +230,17 @@ def _acquire(self) -> None: self._create_acquisition_thread(self._acquisition_delay) return # interrupt early - # select the last self._timestamps.size samples from data and timestamps in - # case more samples than the buffer can hold were retrieved. - data = data[-self._timestamps.size :, :] - timestamps = timestamps[-self._timestamps.size :] - # process acquisition window n_channels = self._inlet.n_channels assert data.ndim == 2 and data.shape[-1] == n_channels, ( - data.shape, - n_channels, + f"Data shape {data.shape} (n_samples, n_channels) for " + f"{n_channels} channels." ) - data = data[:, self._picks_inlet] # subselect channels + # select the last self._timestamps.size samples from data and timestamps in + # case more samples than the buffer can hold were retrieved. + # select channels retained in the buffer. + data = data[-self._timestamps.size :, self._picks_inlet] + timestamps = timestamps[-self._timestamps.size :] if self._stype == "annotations" and np.count_nonzero(data) == 0: if not self._interrupt: self._create_acquisition_thread(self._acquisition_delay) From 4dfff5256675af14738e3c6e896a79e72ece2d9e Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 13:20:10 +0100 Subject: [PATCH 15/69] add notes about reset --- mne_lsl/stream/_base.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index b20253e71..a429a00c4 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -322,6 +322,12 @@ def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStre stream : instance of ``Stream`` The stream instance modified in-place. + Notes + ----- + Dropping channels which are part of a filter will reset the initial conditions + of this filter. The initial conditions will be re-estimated as a step response + steady-state. + See Also -------- pick @@ -385,6 +391,12 @@ def filter( ------- stream : instance of ``Stream`` The stream instance modified in-place. + + Notes + ----- + Adding a filter on channels already filtered will reset the initial conditions + of all channels filtered by the first filter. The initial conditions will be + re-estimated as a step response steady-state. """ self._check_connected_and_regular_sampling("filter()") # validate the arguments @@ -571,6 +583,10 @@ def pick(self, picks, exclude=()) -> BaseStream: Contrary to MNE-Python, re-ordering channels is not supported in ``MNE-LSL``. Thus, if explicit channel names are provided in ``picks``, they are sorted to match the order of existing channel names. + + Dropping channels which are part of a filter will reset the initial conditions + of this filter. The initial conditions will be re-estimated as a step response + steady-state. """ self._check_connected(name="pick()") picks = _picks_to_idx(self._info, picks, "all", exclude, allow_empty=False) From b7745ee5172702f4d0f248f42279304ea3f27d11 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 13:57:40 +0100 Subject: [PATCH 16/69] add idea to sanitize filter list --- mne_lsl/stream/_base.py | 57 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index a429a00c4..fdca1b552 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +from copy import deepcopy from math import ceil from threading import Timer from typing import TYPE_CHECKING @@ -10,7 +11,7 @@ import numpy as np from mne import pick_info, pick_types from mne.channels import rename_channels -from mne.filter import create_filter +from mne.filter import create_filter, estimate_ringing_samples from mne.utils import check_version from scipy.signal import sosfilt_zi @@ -399,7 +400,7 @@ def filter( re-estimated as a step response steady-state. """ self._check_connected_and_regular_sampling("filter()") - # validate the arguments + # validate the arguments and ensure 'sos' output picks = _picks_to_idx(self._info, picks, "all", "bads", allow_empty=False) iir_params = ( dict(order=4, ftype="butter", output="sos") @@ -433,10 +434,17 @@ def filter( ) filter_["zi"] = None # add initial conditions filter_["zi_coeff"] = sosfilt_zi(filter_["sos"])[..., np.newaxis] + # to correctly handle the filter initial conditions even if 2 filters are + # applied to the same channels, we need to separate the 'picks' between filter + # to avoid any channel-overlap between filters. + # if the initial conditions are updated in real-time in the _acquire function, + # we need to update the 'zi' for each individual second order filter in the + # 'sos' output, which does not seem to be supported by scipy directly. filter_["picks"] = picks + filters = _sanitize_filters(self._filters, filter_) # add filter to the list of applied filters with self._interrupt_acquisition(): - self._filters.append(filter_) + self._filters = filters @copy_doc(ContainsMixin.get_channel_types) def get_channel_types( @@ -1055,3 +1063,46 @@ def n_new_samples(self) -> int: """ self._check_connected(name="n_new_samples") return self._n_new_samples + + +def _sanitize_filters( + filters: list[dict[str, Any]], filter_: dict[str, Any] +) -> list[dict[str, Any]]: + """Sanitize the list of filters to ensure non-overlapping channels.""" + filters = deepcopy(filters) + additional_filters = [] + for filt in filters: + intersection = np.intersect1d( + filt["picks"], filter_["picks"], assume_unique=True + ) + if intersection.size == 0: + continue # non-overlapping channels + # create a combined filter + ftype = ( + filt["ftype"] + if filt["ftype"] == filter_["ftype"] + else f"{filt['ftype']}+{filter_['ftype']}" + ) + system = np.vstack((filt["sos"], filter_["sos"])) + combined_filter = { + "order": filter_["sos"].shape[0] + filt["sos"].shape[0], + "ftype": ftype, + "output": "sos", + "padlen": estimate_ringing_samples(system), + "sos": system, + } + combined_filter["zi"] = None + combined_filter["zi_coeff"] = sosfilt_zi(combined_filter["sos"])[ + ..., np.newaxis + ] + combined_filter["picks"] = intersection + additional_filters.append(combined_filter) + # reset initial conditions for the overlapping filter + filt["zi"] = None + # remove overlapping channels from both filters + filt["picks"] = np.setdiff1d(filt["picks"], intersection, assume_unique=True) + filter_["picks"] = np.setdiff1d( + filter_["picks"], intersection, assume_unique=True + ) + additional_filters.append(filter_) + return filters + additional_filters From d92dee1cc1bd56c822084815c1dfcf326ede4ec9 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 15:29:36 +0100 Subject: [PATCH 17/69] add tests --- mne_lsl/stream/_base.py | 54 +++++++++++++--- mne_lsl/stream/tests/test_base.py | 101 ++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 7 deletions(-) create mode 100644 mne_lsl/stream/tests/test_base.py diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index fdca1b552..0ec909a29 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -434,6 +434,9 @@ def filter( ) filter_["zi"] = None # add initial conditions filter_["zi_coeff"] = sosfilt_zi(filter_["sos"])[..., np.newaxis] + # store requested l_freq and h_freq + filter_["l_freq"] = l_freq + filter_["h_freq"] = h_freq # to correctly handle the filter initial conditions even if 2 filters are # applied to the same channels, we need to separate the 'picks' between filter # to avoid any channel-overlap between filters. @@ -441,7 +444,7 @@ def filter( # we need to update the 'zi' for each individual second order filter in the # 'sos' output, which does not seem to be supported by scipy directly. filter_["picks"] = picks - filters = _sanitize_filters(self._filters, filter_) + filters = _sanitize_filters(self._filters, StreamFilter(filter_)) # add filter to the list of applied filters with self._interrupt_acquisition(): self._filters = filters @@ -1066,7 +1069,7 @@ def n_new_samples(self) -> int: def _sanitize_filters( - filters: list[dict[str, Any]], filter_: dict[str, Any] + filters: list[StreamFilter], filter_: StreamFilter ) -> list[dict[str, Any]]: """Sanitize the list of filters to ensure non-overlapping channels.""" filters = deepcopy(filters) @@ -1084,19 +1087,21 @@ def _sanitize_filters( else f"{filt['ftype']}+{filter_['ftype']}" ) system = np.vstack((filt["sos"], filter_["sos"])) + assert filter_["order"] == 2 * filter_["sos"].shape[0] # sanity-check + assert filt["order"] == 2 * filt["sos"].shape[0] # sanity-check combined_filter = { - "order": filter_["sos"].shape[0] + filt["sos"].shape[0], + "order": 2 * system.shape[0], "ftype": ftype, "output": "sos", "padlen": estimate_ringing_samples(system), "sos": system, } combined_filter["zi"] = None - combined_filter["zi_coeff"] = sosfilt_zi(combined_filter["sos"])[ - ..., np.newaxis - ] + combined_filter["zi_coeff"] = sosfilt_zi(system)[..., np.newaxis] combined_filter["picks"] = intersection - additional_filters.append(combined_filter) + combined_filter["l_freq"] = (filt["l_freq"], filter_["l_freq"]) + combined_filter["h_freq"] = (filt["h_freq"], filter_["h_freq"]) + additional_filters.append(StreamFilter(combined_filter)) # reset initial conditions for the overlapping filter filt["zi"] = None # remove overlapping channels from both filters @@ -1106,3 +1111,38 @@ def _sanitize_filters( ) additional_filters.append(filter_) return filters + additional_filters + + +class StreamFilter(dict): + """Class defining a filter.""" + + def __repr__(self): # noqa: D105 + return f"" + + def __eq__(self, other: Any): + """Equality operator.""" + if not isinstance(other, StreamFilter): + return False + if sorted(self) != sorted(other): + return False + for key in self: + type_ = type(self[key]) + if not isinstance(other[key], type_): # sanity-check + warn( + f"The type of the key '{key}' is different between the 2 filters, " + "which should not be possible. Please contact the developers.", + RuntimeWarning, + stacklevel=2, + ) + return False + if type_ is np.ndarray and not np.array_equal( + self[key], other[key], equal_nan=True + ): + return False + elif type_ is not np.ndarray and self[key] != other[key]: + return False + return True + + def __ne__(self, other: Any): # explicit method required to issue warning + """Inequality operator.""" + return not self.__eq__(other) diff --git a/mne_lsl/stream/tests/test_base.py b/mne_lsl/stream/tests/test_base.py new file mode 100644 index 000000000..85e943516 --- /dev/null +++ b/mne_lsl/stream/tests/test_base.py @@ -0,0 +1,101 @@ +from __future__ import annotations # c.f. PEP 563, PEP 649 + +from copy import deepcopy +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from mne.filter import create_filter +from scipy.signal import sosfilt_zi + +from mne_lsl.stream._base import StreamFilter, _sanitize_filters + +if TYPE_CHECKING: + from typing import Any + + +@pytest.fixture(scope="function") +def filters() -> list[dict[str, Any]]: + """Create a list of valid filters.""" + l_freqs = (1, 1, 0.1) + h_freqs = (40, 15, None) + picks = (np.arange(0, 10), np.arange(10, 20), np.arange(20, 30)) + filters = [ + create_filter( + data=None, + sfreq=1000, + l_freq=lfq, + h_freq=hfq, + method="iir", + iir_params=dict(order=4, ftype="butter", output="sos"), + phase="forward", + verbose="ERROR", + ) + for lfq, hfq in zip(l_freqs, h_freqs, strict=True) + ] + for filt, l_fq, h_fq, pick in zip(filters, l_freqs, h_freqs, picks, strict=True): + filt["zi"] = None + filt["zi_coeff"] = sosfilt_zi(filt["sos"]) + filt["picks"] = pick + filt["l_freq"] = l_fq + filt["h_freq"] = h_fq + all_picks = np.hstack([filt["picks"] for filt in filters]) + assert np.unique(all_picks).size == all_picks.size # sanity-check + return [StreamFilter(filter_) for filter_ in filters] + + +def test_sanitize_filters_no_overlap(filters): + """Test clean-up of filter list to ensure non-overlap between channels.""" + filter_ = create_filter( + data=None, + sfreq=1000, + l_freq=None, + h_freq=100, + method="iir", + iir_params=dict(order=4, ftype="butter", output="sos"), + phase="forward", + verbose="ERROR", + ) + filter_["zi"] = None + filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) + filter_["picks"] = np.arange(30, 40) + filter_["l_freq"] = None + filter_["h_freq"] = 100 + filter_ = StreamFilter(filter_) + all_picks = np.hstack([filt["picks"] for filt in filters + [filter_]]) + assert np.unique(all_picks).size == all_picks.size + filters_clean = _sanitize_filters(filters, filter_) + assert len(filters) == 3 + assert len(filters_clean) == 4 + assert filters == filters_clean[:3] + assert filters_clean[-1] not in filters + assert filters_clean[-1]["l_freq"] is None + assert filters_clean[-1]["h_freq"] == 100 + assert np.array_equal(filters_clean[-1]["picks"], np.arange(30, 40)) + assert filters_clean[-1]["order"] == 4 + assert filters_clean[-1]["sos"].shape == (2, 6) + + +def test_StreamFilter(filters): + """Test the StreamFilter class.""" + filter2 = deepcopy(filters[0]) + assert filter2 == filters[0] + assert filters[0] != filters[1] + assert filters[0] != filters[2] + # test different key types + filter2["order"] = str(filter2["order"]) # force different type + with pytest.warns(RuntimeWarning, match="type of the key 'order' is different"): + assert filter2 != filters[0] + # test with nans + filter2 = deepcopy(filters[0]) + filter3 = deepcopy(filters[0]) + filter2["sos"][0, 0] = np.nan + assert filter2 != filter3 + filter3["sos"][0, 0] = np.nan + assert filter2 == filter3 + # test absent key + filter2 = deepcopy(filters[0]) + del filter2["sos"] + assert filter2 != filters[0] + # test representation + assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) From 3a04cdf8806b6d802b57bdbc2f00058c7880b6fd Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 16:16:17 +0100 Subject: [PATCH 18/69] add more tests --- mne_lsl/stream/_base.py | 4 +-- mne_lsl/stream/tests/test_base.py | 46 +++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 0ec909a29..905efbd7a 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1087,10 +1087,8 @@ def _sanitize_filters( else f"{filt['ftype']}+{filter_['ftype']}" ) system = np.vstack((filt["sos"], filter_["sos"])) - assert filter_["order"] == 2 * filter_["sos"].shape[0] # sanity-check - assert filt["order"] == 2 * filt["sos"].shape[0] # sanity-check combined_filter = { - "order": 2 * system.shape[0], + "order": filt["order"] + filter_["order"], "ftype": ftype, "output": "sos", "padlen": estimate_ringing_samples(system), diff --git a/mne_lsl/stream/tests/test_base.py b/mne_lsl/stream/tests/test_base.py index 85e943516..3d9c9f3d9 100644 --- a/mne_lsl/stream/tests/test_base.py +++ b/mne_lsl/stream/tests/test_base.py @@ -76,6 +76,52 @@ def test_sanitize_filters_no_overlap(filters): assert filters_clean[-1]["sos"].shape == (2, 6) +def test_sanitize_filters_partial_overlap(filters): + """Test clean-up of filter list to ensure non-overlap between channels.""" + filter_ = create_filter( + data=None, + sfreq=1000, + l_freq=None, + h_freq=100, + method="iir", + iir_params=dict(order=4, ftype="butter", output="sos"), + phase="forward", + verbose="ERROR", + ) + filter_["zi"] = None + filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) + filter_["picks"] = np.arange(5, 15) + filter_["l_freq"] = None + filter_["h_freq"] = 100 + filter_ = StreamFilter(filter_) + filters_clean = _sanitize_filters(filters, filter_) + assert len(filters) == 3 + assert len(filters_clean) == 5 + + +def test_sanitize_filters_full_overlap(filters): + """Test clean-up of filter list to ensure non-overlap between channels.""" + filter_ = create_filter( + data=None, + sfreq=1000, + l_freq=None, + h_freq=100, + method="iir", + iir_params=dict(order=4, ftype="butter", output="sos"), + phase="forward", + verbose="ERROR", + ) + filter_["zi"] = None + filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) + filter_["picks"] = np.arange(0, 10) + filter_["l_freq"] = None + filter_["h_freq"] = 100 + filter_ = StreamFilter(filter_) + filters_clean = _sanitize_filters(filters, filter_) + assert len(filters) == 3 + assert len(filters_clean) == 3 + + def test_StreamFilter(filters): """Test the StreamFilter class.""" filter2 = deepcopy(filters[0]) From c80c6f5044b5050cf8411322bd0c82e7f73840a7 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 16:19:00 +0100 Subject: [PATCH 19/69] simpler comparison --- mne_lsl/stream/_base.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 905efbd7a..218e2b86d 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1119,9 +1119,7 @@ def __repr__(self): # noqa: D105 def __eq__(self, other: Any): """Equality operator.""" - if not isinstance(other, StreamFilter): - return False - if sorted(self) != sorted(other): + if not isinstance(other, StreamFilter) or sorted(self) != sorted(other): return False for key in self: type_ = type(self[key]) @@ -1133,11 +1131,10 @@ def __eq__(self, other: Any): stacklevel=2, ) return False - if type_ is np.ndarray and not np.array_equal( - self[key], other[key], equal_nan=True - ): - return False - elif type_ is not np.ndarray and self[key] != other[key]: + if ( + type_ is np.ndarray + and not np.array_equal(self[key], other[key], equal_nan=True) + ) or (type_ is not np.ndarray and self[key] != other[key]): return False return True From 518ea87025a02903cf902f3226e20290ebda03d6 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 16:36:45 +0100 Subject: [PATCH 20/69] fix tests --- mne_lsl/stream/_base.py | 11 +++++++++-- mne_lsl/stream/tests/test_base.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 218e2b86d..b35a6fe43 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1107,8 +1107,15 @@ def _sanitize_filters( filter_["picks"] = np.setdiff1d( filter_["picks"], intersection, assume_unique=True ) - additional_filters.append(filter_) - return filters + additional_filters + filters = filters + additional_filters + [filter_] + # prune filters without any channels to apply on + filters2remove = [] + for k, filt in enumerate(filters): + if filt["picks"].size == 0: + filters2remove.append(k) + for k in filters2remove[::-1]: + del filters[k] + return filters class StreamFilter(dict): diff --git a/mne_lsl/stream/tests/test_base.py b/mne_lsl/stream/tests/test_base.py index 3d9c9f3d9..8f6399894 100644 --- a/mne_lsl/stream/tests/test_base.py +++ b/mne_lsl/stream/tests/test_base.py @@ -120,6 +120,18 @@ def test_sanitize_filters_full_overlap(filters): filters_clean = _sanitize_filters(filters, filter_) assert len(filters) == 3 assert len(filters_clean) == 3 + assert filters[1:] == filters_clean[:2] # order is not preserved + assert filters[0]["l_freq"] in filters_clean[-1]["l_freq"] + assert filters[0]["h_freq"] in filters_clean[-1]["h_freq"] + assert filter_["l_freq"] in filters_clean[-1]["l_freq"] + assert filter_["h_freq"] in filters_clean[-1]["h_freq"] + assert np.array_equal(filters_clean[-1]["picks"], np.arange(0, 10)) + assert filters_clean[-1]["zi"] is None + assert not np.array_equal(filters_clean[-1]["zi_coeff"], filters[0]["zi_coeff"]) + assert not np.array_equal(filters_clean[-1]["zi_coeff"], filter_["zi_coeff"]) + assert np.array_equal( + np.vstack((filters[0]["sos"], filter_["sos"])), filters_clean[-1]["sos"] + ) def test_StreamFilter(filters): From ec960f7d5df7e708e795af389f2596cb48452874 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 16:47:05 +0100 Subject: [PATCH 21/69] fix tests --- mne_lsl/stream/_base.py | 6 ++++++ mne_lsl/stream/tests/test_base.py | 23 +++++++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index b35a6fe43..c6392c5f7 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1129,6 +1129,12 @@ def __eq__(self, other: Any): if not isinstance(other, StreamFilter) or sorted(self) != sorted(other): return False for key in self: + if key == "zi": # special case since it's either a np.ndarray or None + if (self[key] is None or other[key] is None) or not np.array_equal( + self[key], other[key] + ): + return False + continue type_ = type(self[key]) if not isinstance(other[key], type_): # sanity-check warn( diff --git a/mne_lsl/stream/tests/test_base.py b/mne_lsl/stream/tests/test_base.py index 8f6399894..8dcc07acf 100644 --- a/mne_lsl/stream/tests/test_base.py +++ b/mne_lsl/stream/tests/test_base.py @@ -33,9 +33,11 @@ def filters() -> list[dict[str, Any]]: ) for lfq, hfq in zip(l_freqs, h_freqs, strict=True) ] - for filt, l_fq, h_fq, pick in zip(filters, l_freqs, h_freqs, picks, strict=True): - filt["zi"] = None - filt["zi_coeff"] = sosfilt_zi(filt["sos"]) + for k, (filt, l_fq, h_fq, pick) in enumerate( + zip(filters, l_freqs, h_freqs, picks, strict=True) + ): + filt["zi_coeff"] = sosfilt_zi(filt["sos"])[..., np.newaxis] + filt["zi"] = filt["zi_coeff"] * k filt["picks"] = pick filt["l_freq"] = l_fq filt["h_freq"] = h_fq @@ -97,6 +99,17 @@ def test_sanitize_filters_partial_overlap(filters): filters_clean = _sanitize_filters(filters, filter_) assert len(filters) == 3 assert len(filters_clean) == 5 + # filter 0 and 1 are overlapping with filter_, thus we should have 2 new filters at + # the end of the list, and only filter 2 should be preserved. + assert filters[2] == filters_clean[2] + assert filters[0] not in filters_clean + assert filters[1] not in filters_clean + # filter 0 and 1 should be lacking some channels + for k, pick in enumerate((np.arange(0, 5), np.arange(15, 20))): + assert np.array_equal(filters_clean[k]["picks"], pick) + assert np.array_equal(filters_clean[k]["sos"], filters[k]["sos"]) + assert np.array_equal(filters_clean[k]["zi_coeff"], filters[k]["zi_coeff"]) + assert filters_clean[k]["zi"] is None def test_sanitize_filters_full_overlap(filters): @@ -120,7 +133,9 @@ def test_sanitize_filters_full_overlap(filters): filters_clean = _sanitize_filters(filters, filter_) assert len(filters) == 3 assert len(filters_clean) == 3 - assert filters[1:] == filters_clean[:2] # order is not preserved + # filter 0 and filter_ fully overlap, thus filter 0 will be removed and the combined + # filter is added to the end of the list -> order is not preserved. + assert filters[1:] == filters_clean[:2] assert filters[0]["l_freq"] in filters_clean[-1]["l_freq"] assert filters[0]["h_freq"] in filters_clean[-1]["h_freq"] assert filter_["l_freq"] in filters_clean[-1]["l_freq"] From 0d2a91de42e65be629fa4bd6762717e27dbca460 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Thu, 29 Feb 2024 16:51:17 +0100 Subject: [PATCH 22/69] more tests --- mne_lsl/stream/tests/test_base.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mne_lsl/stream/tests/test_base.py b/mne_lsl/stream/tests/test_base.py index 8dcc07acf..300bd0b9f 100644 --- a/mne_lsl/stream/tests/test_base.py +++ b/mne_lsl/stream/tests/test_base.py @@ -110,6 +110,30 @@ def test_sanitize_filters_partial_overlap(filters): assert np.array_equal(filters_clean[k]["sos"], filters[k]["sos"]) assert np.array_equal(filters_clean[k]["zi_coeff"], filters[k]["zi_coeff"]) assert filters_clean[k]["zi"] is None + # filter 3 should have the intersection with filter 0 and filter 4 with filter 1 + assert np.array_equal(filters_clean[3]["picks"], np.arange(5, 10)) + assert np.array_equal( + filters_clean[3]["sos"], np.vstack((filters[0]["sos"], filter_["sos"])) + ) + assert not np.array_equal(filters_clean[3]["zi_coeff"], filters[0]["zi_coeff"]) + assert not np.array_equal(filters_clean[3]["zi_coeff"], filter_["zi_coeff"]) + assert filters_clean[3]["zi"] is None + assert np.array_equal(filters_clean[4]["picks"], np.arange(10, 15)) + assert np.array_equal( + filters_clean[4]["sos"], np.vstack((filters[1]["sos"], filter_["sos"])) + ) + assert not np.array_equal(filters_clean[4]["zi_coeff"], filters[1]["zi_coeff"]) + assert not np.array_equal(filters_clean[4]["zi_coeff"], filter_["zi_coeff"]) + assert filters_clean[4]["zi"] is None + # check representation on combined filters + assert filters_clean[3]["l_freq"] == (filters[0]["l_freq"], filter_["l_freq"]) + assert filters_clean[3]["h_freq"] == (filters[0]["h_freq"], filter_["h_freq"]) + assert f"({filters[0]['l_freq']}, {filter_['l_freq']})" in repr(filters_clean[3]) + assert f"({filters[0]['h_freq']}, {filter_['h_freq']})" in repr(filters_clean[3]) + assert filters_clean[4]["l_freq"] == (filters[1]["l_freq"], filter_["l_freq"]) + assert filters_clean[4]["h_freq"] == (filters[1]["h_freq"], filter_["h_freq"]) + assert f"({filters[1]['l_freq']}, {filter_['l_freq']})" in repr(filters_clean[4]) + assert f"({filters[1]['h_freq']}, {filter_['h_freq']})" in repr(filters_clean[4]) def test_sanitize_filters_full_overlap(filters): From 21afacdf0437a3d328336cefabc5fbb6a639e3e6 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 11:11:31 +0100 Subject: [PATCH 23/69] improve definition and handling of filters --- mne_lsl/stream/_base.py | 146 ++++++------------------------- mne_lsl/stream/_filters.py | 162 +++++++++++++++++++++++++++++++++++ mne_lsl/stream/stream_lsl.py | 14 +-- 3 files changed, 195 insertions(+), 127 deletions(-) create mode 100644 mne_lsl/stream/_filters.py diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index c6392c5f7..b694ba6d1 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -11,7 +11,7 @@ import numpy as np from mne import pick_info, pick_types from mne.channels import rename_channels -from mne.filter import create_filter, estimate_ringing_samples +from mne.filter import create_filter from mne.utils import check_version from scipy.signal import sosfilt_zi @@ -33,6 +33,7 @@ from ..utils._docs import copy_doc, fill_doc from ..utils.logs import logger, verbose from ..utils.meas_info import _HUMAN_UNITS, _set_channel_units +from ._filters import StreamFilter, _sanitize_filters if TYPE_CHECKING: from datetime import datetime @@ -323,12 +324,6 @@ def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStre stream : instance of ``Stream`` The stream instance modified in-place. - Notes - ----- - Dropping channels which are part of a filter will reset the initial conditions - of this filter. The initial conditions will be re-estimated as a step response - steady-state. - See Also -------- pick @@ -396,8 +391,8 @@ def filter( Notes ----- Adding a filter on channels already filtered will reset the initial conditions - of all channels filtered by the first filter. The initial conditions will be - re-estimated as a step response steady-state. + of those channels. The initial conditions will be re-estimated as a step + response steady-state to the combination of both filters. """ self._check_connected_and_regular_sampling("filter()") # validate the arguments and ensure 'sos' output @@ -407,6 +402,7 @@ def filter( if iir_params is None else iir_params ) + check_type(iir_params, (dict,), "iir_params") if ("output" in iir_params and iir_params["output"] != "sos") or all( key in iir_params for key in ("a", "b") ): @@ -422,7 +418,7 @@ def filter( del iir_params[key] iir_params["output"] = "sos" # construct an IIR filter - filter_ = create_filter( + filt = create_filter( data=None, sfreq=self._info["sfreq"], l_freq=l_freq, @@ -432,19 +428,26 @@ def filter( phase="forward", verbose=logger.level if verbose is None else verbose, ) - filter_["zi"] = None # add initial conditions - filter_["zi_coeff"] = sosfilt_zi(filter_["sos"])[..., np.newaxis] - # store requested l_freq and h_freq - filter_["l_freq"] = l_freq - filter_["h_freq"] = h_freq + # store filter parameters and initial conditions + filt.update( + zi=None, + zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + l_freq=l_freq, + h_freq=h_freq, + iir_params=iir_params, + sfreq=self._info["sfreq"], + picks=picks, + ) + # remove duplicate information + del filt["order"] + del filt["ftype"] # to correctly handle the filter initial conditions even if 2 filters are # applied to the same channels, we need to separate the 'picks' between filter # to avoid any channel-overlap between filters. # if the initial conditions are updated in real-time in the _acquire function, # we need to update the 'zi' for each individual second order filter in the # 'sos' output, which does not seem to be supported by scipy directly. - filter_["picks"] = picks - filters = _sanitize_filters(self._filters, StreamFilter(filter_)) + filters = _sanitize_filters(self._filters, StreamFilter(filt)) # add filter to the list of applied filters with self._interrupt_acquisition(): self._filters = filters @@ -594,10 +597,6 @@ def pick(self, picks, exclude=()) -> BaseStream: Contrary to MNE-Python, re-ordering channels is not supported in ``MNE-LSL``. Thus, if explicit channel names are provided in ``picks``, they are sorted to match the order of existing channel names. - - Dropping channels which are part of a filter will reset the initial conditions - of this filter. The initial conditions will be re-estimated as a step response - steady-state. """ self._check_connected(name="pick()") picks = _picks_to_idx(self._info, picks, "all", exclude, allow_empty=False) @@ -958,17 +957,12 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: if ch not in self.ch_names: self._added_channels.remove(ch) # remove dropped channels from filters - filters2remove = [] - for k, filter_ in enumerate(self._filters): - filter_["picks"] = np.intersect1d( - filter_["picks"], picks, assume_unique=True - ) - if filter_["picks"].size == 0: - filters2remove.append(k) - continue - filter_["zi"] = None # reset initial conditions - for k in filters2remove[::-1]: - del self._filters[k] + for filt in self._filters: + # TODO: ensure correct selection of channels. + filt["picks"] = np.intersect1d(filt["picks"], picks, assume_unique=True) + # TODO: don't reset, select initial conditions. + filt["zi"] = None + self._filters = [filt for filt in self._filters if filt["picks"].size != 0] @abstractmethod def _reset_variables(self) -> None: @@ -1066,91 +1060,3 @@ def n_new_samples(self) -> int: """ self._check_connected(name="n_new_samples") return self._n_new_samples - - -def _sanitize_filters( - filters: list[StreamFilter], filter_: StreamFilter -) -> list[dict[str, Any]]: - """Sanitize the list of filters to ensure non-overlapping channels.""" - filters = deepcopy(filters) - additional_filters = [] - for filt in filters: - intersection = np.intersect1d( - filt["picks"], filter_["picks"], assume_unique=True - ) - if intersection.size == 0: - continue # non-overlapping channels - # create a combined filter - ftype = ( - filt["ftype"] - if filt["ftype"] == filter_["ftype"] - else f"{filt['ftype']}+{filter_['ftype']}" - ) - system = np.vstack((filt["sos"], filter_["sos"])) - combined_filter = { - "order": filt["order"] + filter_["order"], - "ftype": ftype, - "output": "sos", - "padlen": estimate_ringing_samples(system), - "sos": system, - } - combined_filter["zi"] = None - combined_filter["zi_coeff"] = sosfilt_zi(system)[..., np.newaxis] - combined_filter["picks"] = intersection - combined_filter["l_freq"] = (filt["l_freq"], filter_["l_freq"]) - combined_filter["h_freq"] = (filt["h_freq"], filter_["h_freq"]) - additional_filters.append(StreamFilter(combined_filter)) - # reset initial conditions for the overlapping filter - filt["zi"] = None - # remove overlapping channels from both filters - filt["picks"] = np.setdiff1d(filt["picks"], intersection, assume_unique=True) - filter_["picks"] = np.setdiff1d( - filter_["picks"], intersection, assume_unique=True - ) - filters = filters + additional_filters + [filter_] - # prune filters without any channels to apply on - filters2remove = [] - for k, filt in enumerate(filters): - if filt["picks"].size == 0: - filters2remove.append(k) - for k in filters2remove[::-1]: - del filters[k] - return filters - - -class StreamFilter(dict): - """Class defining a filter.""" - - def __repr__(self): # noqa: D105 - return f"" - - def __eq__(self, other: Any): - """Equality operator.""" - if not isinstance(other, StreamFilter) or sorted(self) != sorted(other): - return False - for key in self: - if key == "zi": # special case since it's either a np.ndarray or None - if (self[key] is None or other[key] is None) or not np.array_equal( - self[key], other[key] - ): - return False - continue - type_ = type(self[key]) - if not isinstance(other[key], type_): # sanity-check - warn( - f"The type of the key '{key}' is different between the 2 filters, " - "which should not be possible. Please contact the developers.", - RuntimeWarning, - stacklevel=2, - ) - return False - if ( - type_ is np.ndarray - and not np.array_equal(self[key], other[key], equal_nan=True) - ) or (type_ is not np.ndarray and self[key] != other[key]): - return False - return True - - def __ne__(self, other: Any): # explicit method required to issue warning - """Inequality operator.""" - return not self.__eq__(other) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py new file mode 100644 index 000000000..dc65b26ec --- /dev/null +++ b/mne_lsl/stream/_filters.py @@ -0,0 +1,162 @@ +from __future__ import annotations # c.f. PEP 563, PEP 649 + +from copy import deepcopy +from typing import TYPE_CHECKING +from warnings import warn + +import numpy as np +from mne.filters import estimate_ringing_samples, create_filter +from scipy.signal import sosfilt_zi + +if TYPE_CHECKING: + from typing import Any + + from numpy.typing import NDArray + + from .._typing import ScalarIntType + + +class StreamFilter(dict): + """Class defining a filter.""" + + def __repr__(self): # noqa: D105 + return f"" + + def __eq__(self, other: Any): + """Equality operator.""" + if not isinstance(other, StreamFilter) or sorted(self) != sorted(other): + return False + for key in self: + if key == "zi": # special case since it's either a np.ndarray or None + if (self[key] is None or other[key] is None) or not np.array_equal( + self[key], other[key] + ): + return False + continue + type_ = type(self[key]) + if not isinstance(other[key], type_): # sanity-check + warn( + f"The type of the key '{key}' is different between the 2 filters, " + "which should not be possible. Please contact the developers.", + RuntimeWarning, + stacklevel=2, + ) + return False + if ( + type_ is np.ndarray + and not np.array_equal(self[key], other[key], equal_nan=True) + ) or (type_ is not np.ndarray and self[key] != other[key]): + return False + return True + + def __ne__(self, other: Any): # explicit method required to issue warning + """Inequality operator.""" + return not self.__eq__(other) + + +def _combine_filters( + filter1: StreamFilter, + filter2: StreamFilter, + picks: NDArray[+ScalarIntType], + *, + copy: bool = True, +) -> StreamFilter: + """Combine 2 filters applied on the same set of channels.""" + assert filter1["sfreq"] == filter2["sfreq"] + if copy: + filter1 = deepcopy(filter1) + filter2 = deepcopy(filter2) + system = np.vstack((filter1["sos"], filter2["sos"])) + # for 'l_freq', 'h_freq', 'iir_params' we store the filter(s) settings in ordered + # tuples to keep track of the original settings of individual filters. + for key in ("l_freq", "h_freq", "iir_params"): + filter1[key] = list( + (filter1[key],) if not isinstance(filter1[key], tuple) else filter1[key] + ) + filter2[key] = list( + (filter2[key],) if not isinstance(filter2[key], tuple) else filter2[key] + ) + combined_filter = { + "output": "sos", + "padlen": estimate_ringing_samples(system), + "sos": system, + "zi": None, # reset initial conditions on channels combined + "zi_coeff": sosfilt_zi(system)[..., np.newaxis], + "l_freq": tuple(filter1["l_freq"] + filter2["l_freq"]), + "h_freq": tuple(filter1["h_freq"] + filter2["h_freq"]), + "iir_params": tuple(filter1["iir_params"] + filter2["iir_params"]), + "sfreq": filter1["sfreq"], + "picks": picks, + } + return StreamFilter(combined_filter) + + +def _uncombine_filters(filt: StreamFilter) -> list[StreamFilter]: + """Uncombine a combined filter into its individual components.""" + val = (isinstance(filt[key], tuple) for key in ("l_freq", "h_freq", "iir_params")) + if not all(val) and any(val): + raise RuntimeError( + "The combined filter contains keys 'l_freq', 'h_freq' and 'iir_params' as " + "both tuple and non-tuple, which should not be possible. Please contact " + "the developers." + ) + elif not all(val): + return [filt] + # instead of trying to un-tangled the 'sos' matrix, we simply create a new filter + # for each individual component. + filters = list() + for lfq, hfq, iir_param in zip( + filt["l_freq"], filt["h_freq"], filt["iir_params"], strict=True + ): + filt = create_filter( + data=None, + sfreq=filt["sfreq"], + l_freq=lfq, + h_freq=hfq, + method="iir", + iir_params=iir_param, + phase="forward", + verbose="CRITICAL", # effectively disable logs + ) + filt.update( + zi=None, + zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + l_freq=lfq, + h_freq=hfq, + iir_params=iir_param, + sfreq=filt["sfreq"], + picks=filt["picks"], + ) + del filt["order"] + del filt["ftype"] + filters.append(StreamFilter(filt)) + return filters + + +def _sanitize_filters( + filters: list[StreamFilter], filter_: StreamFilter, *, copy: bool = True +) -> list[dict[str, Any]]: + """Sanitize the list of filters to ensure non-overlapping channels.""" + filters = deepcopy(filters) if copy else filters + additional_filters = [] + for filt in filters: + intersection = np.intersect1d( + filt["picks"], filter_["picks"], assume_unique=True + ) + if intersection.size == 0: + continue # non-overlapping channels + additional_filters.append(_combine_filters(filt, filter_, picks=intersection)) + # reset initial conditions for the overlapping filter + filt["zi"] = None # TODO: instead of reset, select initial conditions. + # remove overlapping channels from both filters + filt["picks"] = np.setdiff1d(filt["picks"], intersection, assume_unique=True) + filter_["picks"] = np.setdiff1d( + filter_["picks"], intersection, assume_unique=True + ) + # prune filters without any channels + filters = [ + filt + for filt in filters + additional_filters + [filter_] + if filt["picks"].size != 0 + ] + return filters diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index 79b0d38f0..79d83e57c 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -256,17 +256,17 @@ def _acquire(self) -> None: data[:, self._ref_from] -= data_ref # apply filters on (n_times, n_channels) data - for filter_ in self._filters: - if filter_["zi"] is None: + for filt in self._filters: + if filt["zi"] is None: # initial conditions are set to a step response steady-state set # on the mean on the acquisition window (e.g. DC offset for EEGs) - filter_["zi"] = filter_["zi_coeff"] * np.mean( - data[:, filter_["picks"]], axis=0 + filt["zi"] = filt["zi_coeff"] * np.mean( + data[:, filt["picks"]], axis=0 ) - data_filtered, filter_["zi"] = sosfilt( - filter_["sos"], data[:, filter_["picks"]], zi=filter_["zi"], axis=0 + data_filtered, filt["zi"] = sosfilt( + filt["sos"], data[:, filt["picks"]], zi=filt["zi"], axis=0 ) - data[:, filter_["picks"]] = data_filtered + data[:, filt["picks"]] = data_filtered # roll and update buffers self._buffer = np.roll(self._buffer, -timestamps.size, axis=0) From 827123f0a9465805aa79cd37750f0cfcfe82d1c4 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 11:23:08 +0100 Subject: [PATCH 24/69] better --- mne_lsl/stream/_base.py | 2 +- mne_lsl/stream/_filters.py | 2 +- .../tests/{test_base.py => test_filters.py} | 41 ++++++++++++++----- 3 files changed, 32 insertions(+), 13 deletions(-) rename mne_lsl/stream/tests/{test_base.py => test_filters.py} (89%) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index b694ba6d1..e7021e318 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -430,8 +430,8 @@ def filter( ) # store filter parameters and initial conditions filt.update( - zi=None, zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + zi=None, l_freq=l_freq, h_freq=h_freq, iir_params=iir_params, diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index dc65b26ec..898bb61e5 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -119,8 +119,8 @@ def _uncombine_filters(filt: StreamFilter) -> list[StreamFilter]: verbose="CRITICAL", # effectively disable logs ) filt.update( - zi=None, zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + zi=None, l_freq=lfq, h_freq=hfq, iir_params=iir_param, diff --git a/mne_lsl/stream/tests/test_base.py b/mne_lsl/stream/tests/test_filters.py similarity index 89% rename from mne_lsl/stream/tests/test_base.py rename to mne_lsl/stream/tests/test_filters.py index 300bd0b9f..261a1ed53 100644 --- a/mne_lsl/stream/tests/test_base.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -14,8 +14,20 @@ from typing import Any +@pytest.fixture(scope="module") +def iir_params() -> dict[str, Any]: + """Return a dictionary with valid IIR parameters.""" + return dict(order=4, ftype="butter", output="sos") + + +@pytest.fixture(scope="module") +def sfreq() -> int: + """Return a valid sampling frequency.""" + return 1000 + + @pytest.fixture(scope="function") -def filters() -> list[dict[str, Any]]: +def filters(iir_params, sfreq) -> list[dict[str, Any]]: """Create a list of valid filters.""" l_freqs = (1, 1, 0.1) h_freqs = (40, 15, None) @@ -23,27 +35,34 @@ def filters() -> list[dict[str, Any]]: filters = [ create_filter( data=None, - sfreq=1000, + sfreq=sfreq, l_freq=lfq, h_freq=hfq, method="iir", - iir_params=dict(order=4, ftype="butter", output="sos"), + iir_params=iir_params, phase="forward", - verbose="ERROR", + verbose="CRITICAL", # disable logs ) for lfq, hfq in zip(l_freqs, h_freqs, strict=True) ] - for k, (filt, l_fq, h_fq, pick) in enumerate( + for k, (filt, lfq, hfq, picks_) in enumerate( zip(filters, l_freqs, h_freqs, picks, strict=True) ): - filt["zi_coeff"] = sosfilt_zi(filt["sos"])[..., np.newaxis] - filt["zi"] = filt["zi_coeff"] * k - filt["picks"] = pick - filt["l_freq"] = l_fq - filt["h_freq"] = h_fq + zi_coeff = sosfilt_zi(filt["sos"])[..., np.newaxis] + filt.update( + zi_coeff=zi_coeff, + zi=zi_coeff * k, + l_freq=lfq, + h_freq=hfq, + iir_params=iir_params, + sfreq=sfreq, + picks=picks_, + ) + del filt["order"] + del filt["ftype"] all_picks = np.hstack([filt["picks"] for filt in filters]) assert np.unique(all_picks).size == all_picks.size # sanity-check - return [StreamFilter(filter_) for filter_ in filters] + return [StreamFilter(filt) for filt in filters] def test_sanitize_filters_no_overlap(filters): From 585ed5dd93af36465a9f26e620ee6a54aa2cc756 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 11:52:32 +0100 Subject: [PATCH 25/69] fix typos --- mne_lsl/stream/_filters.py | 2 +- mne_lsl/stream/tests/test_filters.py | 56 ++++++++++++++-------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index 898bb61e5..093e470cb 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -5,7 +5,7 @@ from warnings import warn import numpy as np -from mne.filters import estimate_ringing_samples, create_filter +from mne.filter import estimate_ringing_samples, create_filter from scipy.signal import sosfilt_zi if TYPE_CHECKING: diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 261a1ed53..5b35ebb56 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -65,6 +65,31 @@ def filters(iir_params, sfreq) -> list[dict[str, Any]]: return [StreamFilter(filt) for filt in filters] +def test_StreamFilter(filters): + """Test the StreamFilter class.""" + filter2 = deepcopy(filters[0]) + assert filter2 == filters[0] + assert filters[0] != filters[1] + assert filters[0] != filters[2] + # test different key types + filter2["l_freq"] = str(filter2["l_freq"]) # force different type + with pytest.warns(RuntimeWarning, match="type of the key 'l_freq' is different"): + assert filter2 != filters[0] + # test with nans + filter2 = deepcopy(filters[0]) + filter3 = deepcopy(filters[0]) + filter2["sos"][0, 0] = np.nan + assert filter2 != filter3 + filter3["sos"][0, 0] = np.nan + assert filter2 == filter3 + # test absent key + filter2 = deepcopy(filters[0]) + del filter2["sos"] + assert filter2 != filters[0] + # test representation + assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) + + def test_sanitize_filters_no_overlap(filters): """Test clean-up of filter list to ensure non-overlap between channels.""" filter_ = create_filter( @@ -75,7 +100,7 @@ def test_sanitize_filters_no_overlap(filters): method="iir", iir_params=dict(order=4, ftype="butter", output="sos"), phase="forward", - verbose="ERROR", + verbose="CRITICAL", ) filter_["zi"] = None filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) @@ -107,7 +132,7 @@ def test_sanitize_filters_partial_overlap(filters): method="iir", iir_params=dict(order=4, ftype="butter", output="sos"), phase="forward", - verbose="ERROR", + verbose="CRITICAL", ) filter_["zi"] = None filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) @@ -165,7 +190,7 @@ def test_sanitize_filters_full_overlap(filters): method="iir", iir_params=dict(order=4, ftype="butter", output="sos"), phase="forward", - verbose="ERROR", + verbose="CRITICAL", ) filter_["zi"] = None filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) @@ -190,28 +215,3 @@ def test_sanitize_filters_full_overlap(filters): assert np.array_equal( np.vstack((filters[0]["sos"], filter_["sos"])), filters_clean[-1]["sos"] ) - - -def test_StreamFilter(filters): - """Test the StreamFilter class.""" - filter2 = deepcopy(filters[0]) - assert filter2 == filters[0] - assert filters[0] != filters[1] - assert filters[0] != filters[2] - # test different key types - filter2["order"] = str(filter2["order"]) # force different type - with pytest.warns(RuntimeWarning, match="type of the key 'order' is different"): - assert filter2 != filters[0] - # test with nans - filter2 = deepcopy(filters[0]) - filter3 = deepcopy(filters[0]) - filter2["sos"][0, 0] = np.nan - assert filter2 != filter3 - filter3["sos"][0, 0] = np.nan - assert filter2 == filter3 - # test absent key - filter2 = deepcopy(filters[0]) - del filter2["sos"] - assert filter2 != filters[0] - # test representation - assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) From 7f0d1d01c9c510e9b34b95764e3509044f429cb6 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:19:30 +0100 Subject: [PATCH 26/69] add test for (un)combination of filters --- mne_lsl/stream/_filters.py | 20 +- mne_lsl/stream/tests/test_filters.py | 298 ++++++++++++++++++++++++--- 2 files changed, 280 insertions(+), 38 deletions(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index 093e470cb..aa3991179 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -28,7 +28,9 @@ def __eq__(self, other: Any): return False for key in self: if key == "zi": # special case since it's either a np.ndarray or None - if (self[key] is None or other[key] is None) or not np.array_equal( + if self[key] is None and other[key] is None: + continue + elif ((self[key] is None) ^ (other[key] is None)) or not np.array_equal( self[key], other[key] ): return False @@ -91,9 +93,11 @@ def _combine_filters( return StreamFilter(combined_filter) -def _uncombine_filters(filt: StreamFilter) -> list[StreamFilter]: +def _uncombine_filters(filter_: StreamFilter) -> list[StreamFilter]: """Uncombine a combined filter into its individual components.""" - val = (isinstance(filt[key], tuple) for key in ("l_freq", "h_freq", "iir_params")) + val = ( + isinstance(filter_[key], tuple) for key in ("l_freq", "h_freq", "iir_params") + ) if not all(val) and any(val): raise RuntimeError( "The combined filter contains keys 'l_freq', 'h_freq' and 'iir_params' as " @@ -101,16 +105,16 @@ def _uncombine_filters(filt: StreamFilter) -> list[StreamFilter]: "the developers." ) elif not all(val): - return [filt] + return [filter_] # instead of trying to un-tangled the 'sos' matrix, we simply create a new filter # for each individual component. filters = list() for lfq, hfq, iir_param in zip( - filt["l_freq"], filt["h_freq"], filt["iir_params"], strict=True + filter_["l_freq"], filter_["h_freq"], filter_["iir_params"] ): filt = create_filter( data=None, - sfreq=filt["sfreq"], + sfreq=filter_["sfreq"], l_freq=lfq, h_freq=hfq, method="iir", @@ -124,8 +128,8 @@ def _uncombine_filters(filt: StreamFilter) -> list[StreamFilter]: l_freq=lfq, h_freq=hfq, iir_params=iir_param, - sfreq=filt["sfreq"], - picks=filt["picks"], + sfreq=filter_["sfreq"], + picks=filter_["picks"], ) del filt["order"] del filt["ftype"] diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 5b35ebb56..70a27eacf 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -8,11 +8,43 @@ from mne.filter import create_filter from scipy.signal import sosfilt_zi -from mne_lsl.stream._base import StreamFilter, _sanitize_filters +from mne_lsl.stream._filters import ( + StreamFilter, + _combine_filters, + _sanitize_filters, + _uncombine_filters, +) if TYPE_CHECKING: from typing import Any + from numpy.typing import NDArray + + +def test_StreamFilter(filters): + """Test the StreamFilter class.""" + filter2 = deepcopy(filters[0]) + assert filter2 == filters[0] + assert filters[0] != filters[1] + assert filters[0] != filters[2] + # test different key types + filter2["l_freq"] = str(filter2["l_freq"]) # force different type + with pytest.warns(RuntimeWarning, match="type of the key 'l_freq' is different"): + assert filter2 != filters[0] + # test with nans + filter2 = deepcopy(filters[0]) + filter3 = deepcopy(filters[0]) + filter2["sos"][0, 0] = np.nan + assert filter2 != filter3 + filter3["sos"][0, 0] = np.nan + assert filter2 == filter3 + # test absent key + filter2 = deepcopy(filters[0]) + del filter2["sos"] + assert filter2 != filters[0] + # test representation + assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) + @pytest.fixture(scope="module") def iir_params() -> dict[str, Any]: @@ -21,9 +53,190 @@ def iir_params() -> dict[str, Any]: @pytest.fixture(scope="module") -def sfreq() -> int: +def sfreq() -> float: """Return a valid sampling frequency.""" - return 1000 + return 1000.0 + + +@pytest.fixture(scope="module") +def picks() -> NDArray[np.int32]: + """Return a valid selection of channels.""" + return np.arange(0, 10, dtype=np.int32) + + +@pytest.fixture(scope="function") +def filter1( + iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int32] +) -> StreamFilter: + """Create a filter.""" + l_freq = 1.0 + h_freq = 40.0 + filt = create_filter( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=h_freq, + method="iir", + iir_params=iir_params, + phase="forward", + verbose="CRITICAL", + ) + filt.update( + zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + zi=None, + l_freq=l_freq, + h_freq=h_freq, + iir_params=iir_params, + sfreq=sfreq, + picks=picks, + ) + del filt["order"] + del filt["ftype"] + return StreamFilter(filt) + + +@pytest.fixture(scope="function") +def filter2( + iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int32] +) -> StreamFilter: + """Create a filter.""" + l_freq = 2.0 + h_freq = None + filt = create_filter( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=h_freq, + method="iir", + iir_params=iir_params, + phase="forward", + verbose="CRITICAL", + ) + filt.update( + zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + zi=None, + l_freq=l_freq, + h_freq=h_freq, + iir_params=iir_params, + sfreq=sfreq, + picks=picks, + ) + del filt["order"] + del filt["ftype"] + return StreamFilter(filt) + + +@pytest.fixture(scope="function") +def filter3(sfreq: float, picks: NDArray[np.int32]) -> StreamFilter: + """Create a filter.""" + l_freq = None + h_freq = 80.0 + iir_params = dict(order=2, ftype="bessel", output="sos") + filt = create_filter( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=h_freq, + method="iir", + iir_params=iir_params, + phase="forward", + verbose="CRITICAL", + ) + filt.update( + zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + zi=None, + l_freq=l_freq, + h_freq=h_freq, + iir_params=iir_params, + sfreq=sfreq, + picks=picks, + ) + del filt["order"] + del filt["ftype"] + return StreamFilter(filt) + + +def test_combine_uncombine_filters(filter1, filter2, filter3, picks): + """Test (un)combinatation of filters.""" + filt = _combine_filters(filter1, filter2, picks) + assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter2["sos"]))) + assert filt["sos"].shape[-1] == 6 + assert filt["l_freq"] == (filter1["l_freq"], filter2["l_freq"]) + assert filt["h_freq"] == (filter1["h_freq"], filter2["h_freq"]) + assert not np.array_equal(filt["zi_coeff"], filter1["zi_coeff"]) + assert not np.array_equal(filt["zi_coeff"], filter2["zi_coeff"]) + assert filt["zi"] is None + assert np.array_equal(filt["picks"], filter1["picks"]) + assert np.array_equal(filt["picks"], filter2["picks"]) + assert filt["sfreq"] == filter1["sfreq"] == filter2["sfreq"] + assert filt["iir_params"] == (filter1["iir_params"], filter2["iir_params"]) + filt1, filt2 = _uncombine_filters(filt) + assert filt1 == filter1 + assert filt2 == filter2 + + # add initial conditions + filter2["zi"] = filter2["zi_coeff"] * 5 + assert filt2 != filter2 + filt = _combine_filters(filter1, filter2, picks) + assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter2["sos"]))) + assert filt["sos"].shape[-1] == 6 + assert filt["l_freq"] == (filter1["l_freq"], filter2["l_freq"]) + assert filt["h_freq"] == (filter1["h_freq"], filter2["h_freq"]) + assert not np.array_equal(filt["zi_coeff"], filter1["zi_coeff"]) + assert not np.array_equal(filt["zi_coeff"], filter2["zi_coeff"]) + assert filt["zi"] is None + assert np.array_equal(filt["picks"], filter1["picks"]) + assert np.array_equal(filt["picks"], filter2["picks"]) + assert filt["sfreq"] == filter1["sfreq"] == filter2["sfreq"] + assert filt["iir_params"] == (filter1["iir_params"], filter2["iir_params"]) + filt1, filt2 = _uncombine_filters(filt) + assert filt1 == filter1 + assert filt2 != filter2 + filter2["zi"] = None + assert filt2 == filter2 + + # test with different filter type + filt = _combine_filters(filter1, filter3, picks) + assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter3["sos"]))) + assert filt["sos"].shape[-1] == 6 + assert filt["l_freq"] == (filter1["l_freq"], filter3["l_freq"]) + assert filt["h_freq"] == (filter1["h_freq"], filter3["h_freq"]) + assert not np.array_equal(filt["zi_coeff"], filter1["zi_coeff"]) + assert not np.array_equal(filt["zi_coeff"], filter3["zi_coeff"]) + assert filt["zi"] is None + assert np.array_equal(filt["picks"], filter1["picks"]) + assert np.array_equal(filt["picks"], filter3["picks"]) + assert filt["sfreq"] == filter1["sfreq"] == filter3["sfreq"] + assert filt["iir_params"] == (filter1["iir_params"], filter3["iir_params"]) + filt1, filt3 = _uncombine_filters(filt) + assert filt1 == filter1 + assert filt3 == filter3 + + # test combination of 3 filters + filt_ = _combine_filters(filt, filter2, picks) + assert np.array_equal( + filt_["sos"], np.vstack((filter1["sos"], filter3["sos"], filter2["sos"])) + ) + assert np.array_equal(filt_["sos"], np.vstack((filt["sos"], filter2["sos"]))) + assert filt_["sos"].shape[-1] == 6 + assert filt_["l_freq"] == (filter1["l_freq"], filter3["l_freq"], filter2["l_freq"]) + assert filt_["h_freq"] == (filter1["h_freq"], filter3["h_freq"], filter2["h_freq"]) + assert not np.array_equal(filt_["zi_coeff"], filt["zi_coeff"]) + assert not np.array_equal(filt_["zi_coeff"], filter1["zi_coeff"]) + assert not np.array_equal(filt_["zi_coeff"], filter2["zi_coeff"]) + assert not np.array_equal(filt_["zi_coeff"], filter3["zi_coeff"]) + assert filt_["zi"] is None + assert np.array_equal(filt_["picks"], picks) + assert filt_["sfreq"] == filter1["sfreq"] == filter2["sfreq"] == filter3["sfreq"] + assert filt_["iir_params"] == ( + filter1["iir_params"], + filter3["iir_params"], + filter2["iir_params"], + ) + filt1, filt3, filt2 = _uncombine_filters(filt_) + assert filt1 == filter1 + assert filt2 == filter2 # zi already set to None + assert filt3 == filter3 @pytest.fixture(scope="function") @@ -43,11 +256,9 @@ def filters(iir_params, sfreq) -> list[dict[str, Any]]: phase="forward", verbose="CRITICAL", # disable logs ) - for lfq, hfq in zip(l_freqs, h_freqs, strict=True) + for lfq, hfq in zip(l_freqs, h_freqs) ] - for k, (filt, lfq, hfq, picks_) in enumerate( - zip(filters, l_freqs, h_freqs, picks, strict=True) - ): + for k, (filt, lfq, hfq, picks_) in enumerate(zip(filters, l_freqs, h_freqs, picks)): zi_coeff = sosfilt_zi(filt["sos"])[..., np.newaxis] filt.update( zi_coeff=zi_coeff, @@ -65,29 +276,56 @@ def filters(iir_params, sfreq) -> list[dict[str, Any]]: return [StreamFilter(filt) for filt in filters] -def test_StreamFilter(filters): - """Test the StreamFilter class.""" - filter2 = deepcopy(filters[0]) - assert filter2 == filters[0] - assert filters[0] != filters[1] - assert filters[0] != filters[2] - # test different key types - filter2["l_freq"] = str(filter2["l_freq"]) # force different type - with pytest.warns(RuntimeWarning, match="type of the key 'l_freq' is different"): - assert filter2 != filters[0] - # test with nans - filter2 = deepcopy(filters[0]) - filter3 = deepcopy(filters[0]) - filter2["sos"][0, 0] = np.nan - assert filter2 != filter3 - filter3["sos"][0, 0] = np.nan - assert filter2 == filter3 - # test absent key - filter2 = deepcopy(filters[0]) - del filter2["sos"] - assert filter2 != filters[0] - # test representation - assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) +@pytest.fixture( + scope="function", + params=[ + (None, 100, np.arange(30, 40)), + (10, 100, np.arange(30, 40)), + (50, None, np.arange(30, 40)), + (None, 100, np.arange(0, 10)), + (10, 100, np.arange(0, 10)), + (50, None, np.arange(0, 10)), + (None, 100, np.arange(5, 15)), + (10, 100, np.arange(5, 15)), + (50, None, np.arange(5, 15)), + (None, 100, np.arange(5, 10)), + (10, 100, np.arange(5, 10)), + (50, None, np.arange(5, 10)), + (None, 100, np.arange(5, 25)), + (10, 100, np.arange(5, 25)), + (50, None, np.arange(5, 25)), + ], +) +def filter_(request, iir_params: dict[str, Any], sfreq: float) -> StreamFilter: + """Create a filter.""" + l_freq, h_freq, picks = request.param + filt = create_filter( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=h_freq, + method="iir", + iir_params=iir_params, + phase="forward", + verbose="CRITICAL", + ) + filt.update( + zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], + zi=None, + l_freq=l_freq, + h_freq=h_freq, + iir_params=iir_params, + sfreq=sfreq, + picks=picks, + ) + del filt["order"] + del filt["ftype"] + return StreamFilter(filt) + + +def test_sanitize_filters(filters, filter_): + """Test clean-up of filter list to ensure non-overlap between channels.""" + pass def test_sanitize_filters_no_overlap(filters): From aafc5f4934f327a12d9fe3d1f574953bea1ccda6 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:21:06 +0100 Subject: [PATCH 27/69] rm bad sanitize_filters tests --- mne_lsl/stream/tests/test_filters.py | 127 --------------------------- 1 file changed, 127 deletions(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 70a27eacf..1873fbc75 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -326,130 +326,3 @@ def filter_(request, iir_params: dict[str, Any], sfreq: float) -> StreamFilter: def test_sanitize_filters(filters, filter_): """Test clean-up of filter list to ensure non-overlap between channels.""" pass - - -def test_sanitize_filters_no_overlap(filters): - """Test clean-up of filter list to ensure non-overlap between channels.""" - filter_ = create_filter( - data=None, - sfreq=1000, - l_freq=None, - h_freq=100, - method="iir", - iir_params=dict(order=4, ftype="butter", output="sos"), - phase="forward", - verbose="CRITICAL", - ) - filter_["zi"] = None - filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) - filter_["picks"] = np.arange(30, 40) - filter_["l_freq"] = None - filter_["h_freq"] = 100 - filter_ = StreamFilter(filter_) - all_picks = np.hstack([filt["picks"] for filt in filters + [filter_]]) - assert np.unique(all_picks).size == all_picks.size - filters_clean = _sanitize_filters(filters, filter_) - assert len(filters) == 3 - assert len(filters_clean) == 4 - assert filters == filters_clean[:3] - assert filters_clean[-1] not in filters - assert filters_clean[-1]["l_freq"] is None - assert filters_clean[-1]["h_freq"] == 100 - assert np.array_equal(filters_clean[-1]["picks"], np.arange(30, 40)) - assert filters_clean[-1]["order"] == 4 - assert filters_clean[-1]["sos"].shape == (2, 6) - - -def test_sanitize_filters_partial_overlap(filters): - """Test clean-up of filter list to ensure non-overlap between channels.""" - filter_ = create_filter( - data=None, - sfreq=1000, - l_freq=None, - h_freq=100, - method="iir", - iir_params=dict(order=4, ftype="butter", output="sos"), - phase="forward", - verbose="CRITICAL", - ) - filter_["zi"] = None - filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) - filter_["picks"] = np.arange(5, 15) - filter_["l_freq"] = None - filter_["h_freq"] = 100 - filter_ = StreamFilter(filter_) - filters_clean = _sanitize_filters(filters, filter_) - assert len(filters) == 3 - assert len(filters_clean) == 5 - # filter 0 and 1 are overlapping with filter_, thus we should have 2 new filters at - # the end of the list, and only filter 2 should be preserved. - assert filters[2] == filters_clean[2] - assert filters[0] not in filters_clean - assert filters[1] not in filters_clean - # filter 0 and 1 should be lacking some channels - for k, pick in enumerate((np.arange(0, 5), np.arange(15, 20))): - assert np.array_equal(filters_clean[k]["picks"], pick) - assert np.array_equal(filters_clean[k]["sos"], filters[k]["sos"]) - assert np.array_equal(filters_clean[k]["zi_coeff"], filters[k]["zi_coeff"]) - assert filters_clean[k]["zi"] is None - # filter 3 should have the intersection with filter 0 and filter 4 with filter 1 - assert np.array_equal(filters_clean[3]["picks"], np.arange(5, 10)) - assert np.array_equal( - filters_clean[3]["sos"], np.vstack((filters[0]["sos"], filter_["sos"])) - ) - assert not np.array_equal(filters_clean[3]["zi_coeff"], filters[0]["zi_coeff"]) - assert not np.array_equal(filters_clean[3]["zi_coeff"], filter_["zi_coeff"]) - assert filters_clean[3]["zi"] is None - assert np.array_equal(filters_clean[4]["picks"], np.arange(10, 15)) - assert np.array_equal( - filters_clean[4]["sos"], np.vstack((filters[1]["sos"], filter_["sos"])) - ) - assert not np.array_equal(filters_clean[4]["zi_coeff"], filters[1]["zi_coeff"]) - assert not np.array_equal(filters_clean[4]["zi_coeff"], filter_["zi_coeff"]) - assert filters_clean[4]["zi"] is None - # check representation on combined filters - assert filters_clean[3]["l_freq"] == (filters[0]["l_freq"], filter_["l_freq"]) - assert filters_clean[3]["h_freq"] == (filters[0]["h_freq"], filter_["h_freq"]) - assert f"({filters[0]['l_freq']}, {filter_['l_freq']})" in repr(filters_clean[3]) - assert f"({filters[0]['h_freq']}, {filter_['h_freq']})" in repr(filters_clean[3]) - assert filters_clean[4]["l_freq"] == (filters[1]["l_freq"], filter_["l_freq"]) - assert filters_clean[4]["h_freq"] == (filters[1]["h_freq"], filter_["h_freq"]) - assert f"({filters[1]['l_freq']}, {filter_['l_freq']})" in repr(filters_clean[4]) - assert f"({filters[1]['h_freq']}, {filter_['h_freq']})" in repr(filters_clean[4]) - - -def test_sanitize_filters_full_overlap(filters): - """Test clean-up of filter list to ensure non-overlap between channels.""" - filter_ = create_filter( - data=None, - sfreq=1000, - l_freq=None, - h_freq=100, - method="iir", - iir_params=dict(order=4, ftype="butter", output="sos"), - phase="forward", - verbose="CRITICAL", - ) - filter_["zi"] = None - filter_["zi_coeff"] = sosfilt_zi(filter_["sos"]) - filter_["picks"] = np.arange(0, 10) - filter_["l_freq"] = None - filter_["h_freq"] = 100 - filter_ = StreamFilter(filter_) - filters_clean = _sanitize_filters(filters, filter_) - assert len(filters) == 3 - assert len(filters_clean) == 3 - # filter 0 and filter_ fully overlap, thus filter 0 will be removed and the combined - # filter is added to the end of the list -> order is not preserved. - assert filters[1:] == filters_clean[:2] - assert filters[0]["l_freq"] in filters_clean[-1]["l_freq"] - assert filters[0]["h_freq"] in filters_clean[-1]["h_freq"] - assert filter_["l_freq"] in filters_clean[-1]["l_freq"] - assert filter_["h_freq"] in filters_clean[-1]["h_freq"] - assert np.array_equal(filters_clean[-1]["picks"], np.arange(0, 10)) - assert filters_clean[-1]["zi"] is None - assert not np.array_equal(filters_clean[-1]["zi_coeff"], filters[0]["zi_coeff"]) - assert not np.array_equal(filters_clean[-1]["zi_coeff"], filter_["zi_coeff"]) - assert np.array_equal( - np.vstack((filters[0]["sos"], filter_["sos"])), filters_clean[-1]["sos"] - ) From 9ff3e5513e8373479333d0863fc2b76e5500ab65 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:23:59 +0100 Subject: [PATCH 28/69] improve type-hints --- mne_lsl/stream/tests/test_filters.py | 71 +++++++++++++++------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 1873fbc75..d74c33e21 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -21,31 +21,6 @@ from numpy.typing import NDArray -def test_StreamFilter(filters): - """Test the StreamFilter class.""" - filter2 = deepcopy(filters[0]) - assert filter2 == filters[0] - assert filters[0] != filters[1] - assert filters[0] != filters[2] - # test different key types - filter2["l_freq"] = str(filter2["l_freq"]) # force different type - with pytest.warns(RuntimeWarning, match="type of the key 'l_freq' is different"): - assert filter2 != filters[0] - # test with nans - filter2 = deepcopy(filters[0]) - filter3 = deepcopy(filters[0]) - filter2["sos"][0, 0] = np.nan - assert filter2 != filter3 - filter3["sos"][0, 0] = np.nan - assert filter2 == filter3 - # test absent key - filter2 = deepcopy(filters[0]) - del filter2["sos"] - assert filter2 != filters[0] - # test representation - assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) - - @pytest.fixture(scope="module") def iir_params() -> dict[str, Any]: """Return a dictionary with valid IIR parameters.""" @@ -59,14 +34,14 @@ def sfreq() -> float: @pytest.fixture(scope="module") -def picks() -> NDArray[np.int32]: +def picks() -> NDArray[np.int8]: """Return a valid selection of channels.""" - return np.arange(0, 10, dtype=np.int32) + return np.arange(0, 10, dtype=np.int8) @pytest.fixture(scope="function") def filter1( - iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int32] + iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int8] ) -> StreamFilter: """Create a filter.""" l_freq = 1.0 @@ -97,7 +72,7 @@ def filter1( @pytest.fixture(scope="function") def filter2( - iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int32] + iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int8] ) -> StreamFilter: """Create a filter.""" l_freq = 2.0 @@ -127,7 +102,7 @@ def filter2( @pytest.fixture(scope="function") -def filter3(sfreq: float, picks: NDArray[np.int32]) -> StreamFilter: +def filter3(sfreq: float, picks: NDArray[np.int8]) -> StreamFilter: """Create a filter.""" l_freq = None h_freq = 80.0 @@ -156,7 +131,12 @@ def filter3(sfreq: float, picks: NDArray[np.int32]) -> StreamFilter: return StreamFilter(filt) -def test_combine_uncombine_filters(filter1, filter2, filter3, picks): +def test_combine_uncombine_filters( + filter1: StreamFilter, + filter2: StreamFilter, + filter3: StreamFilter, + picks: NDArray[np.int8], +): """Test (un)combinatation of filters.""" filt = _combine_filters(filter1, filter2, picks) assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter2["sos"]))) @@ -240,7 +220,7 @@ def test_combine_uncombine_filters(filter1, filter2, filter3, picks): @pytest.fixture(scope="function") -def filters(iir_params, sfreq) -> list[dict[str, Any]]: +def filters(iir_params: dict[str, Any], sfreq: float) -> list[StreamFilter]: """Create a list of valid filters.""" l_freqs = (1, 1, 0.1) h_freqs = (40, 15, None) @@ -276,6 +256,31 @@ def filters(iir_params, sfreq) -> list[dict[str, Any]]: return [StreamFilter(filt) for filt in filters] +def test_StreamFilter(filters: StreamFilter): + """Test the StreamFilter class.""" + filter2 = deepcopy(filters[0]) + assert filter2 == filters[0] + assert filters[0] != filters[1] + assert filters[0] != filters[2] + # test different key types + filter2["l_freq"] = str(filter2["l_freq"]) # force different type + with pytest.warns(RuntimeWarning, match="type of the key 'l_freq' is different"): + assert filter2 != filters[0] + # test with nans + filter2 = deepcopy(filters[0]) + filter3 = deepcopy(filters[0]) + filter2["sos"][0, 0] = np.nan + assert filter2 != filter3 + filter3["sos"][0, 0] = np.nan + assert filter2 == filter3 + # test absent key + filter2 = deepcopy(filters[0]) + del filter2["sos"] + assert filter2 != filters[0] + # test representation + assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) + + @pytest.fixture( scope="function", params=[ @@ -323,6 +328,6 @@ def filter_(request, iir_params: dict[str, Any], sfreq: float) -> StreamFilter: return StreamFilter(filt) -def test_sanitize_filters(filters, filter_): +def test_sanitize_filters(filters: list[StreamFilter], filter_: StreamFilter): """Test clean-up of filter list to ensure non-overlap between channels.""" pass From 10eef6055acd8e56f8f955d760cc2ddf449e0015 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:27:53 +0100 Subject: [PATCH 29/69] trigger cis From 540fa453102690a437e10cb5bc93bd9fbd2cd0f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:33:04 +0000 Subject: [PATCH 30/69] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne_lsl/stream/_base.py | 1 - mne_lsl/stream/_filters.py | 2 +- mne_lsl/stream/tests/test_filters.py | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index e7021e318..167569931 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from copy import deepcopy from math import ceil from threading import Timer from typing import TYPE_CHECKING diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index aa3991179..f0bcce613 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -5,7 +5,7 @@ from warnings import warn import numpy as np -from mne.filter import estimate_ringing_samples, create_filter +from mne.filter import create_filter, estimate_ringing_samples from scipy.signal import sosfilt_zi if TYPE_CHECKING: diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index d74c33e21..97769687e 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -11,7 +11,6 @@ from mne_lsl.stream._filters import ( StreamFilter, _combine_filters, - _sanitize_filters, _uncombine_filters, ) @@ -137,7 +136,7 @@ def test_combine_uncombine_filters( filter3: StreamFilter, picks: NDArray[np.int8], ): - """Test (un)combinatation of filters.""" + """Test (un)combination of filters.""" filt = _combine_filters(filter1, filter2, picks) assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter2["sos"]))) assert filt["sos"].shape[-1] == 6 From 26526bb7227b9f3f5d6b6026e96c6747d08cc944 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:44:31 +0100 Subject: [PATCH 31/69] add tests and fix typos --- mne_lsl/stream/_filters.py | 6 +++--- mne_lsl/stream/tests/test_filters.py | 21 ++++++++++++++++++++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index f0bcce613..accfce3c8 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -95,16 +95,16 @@ def _combine_filters( def _uncombine_filters(filter_: StreamFilter) -> list[StreamFilter]: """Uncombine a combined filter into its individual components.""" - val = ( + val = [ isinstance(filter_[key], tuple) for key in ("l_freq", "h_freq", "iir_params") - ) + ] if not all(val) and any(val): raise RuntimeError( "The combined filter contains keys 'l_freq', 'h_freq' and 'iir_params' as " "both tuple and non-tuple, which should not be possible. Please contact " "the developers." ) - elif not all(val): + elif all(elt is False for elt in val): return [filter_] # instead of trying to un-tangled the 'sos' matrix, we simply create a new filter # for each individual component. diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 97769687e..62ba79c25 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -137,6 +137,11 @@ def test_combine_uncombine_filters( picks: NDArray[np.int8], ): """Test (un)combination of filters.""" + # uncombine self + filt = _uncombine_filters(filter1) + assert filter1 == filt[0] + + # combine 2 filters filt = _combine_filters(filter1, filter2, picks) assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter2["sos"]))) assert filt["sos"].shape[-1] == 6 @@ -218,6 +223,20 @@ def test_combine_uncombine_filters( assert filt3 == filter3 +def test_invalid_uncombine_filters(filter1, filter2, picks): + """Test error raising in uncombine filters.""" + filt = _combine_filters(filter1, filter2, picks) + filt["l_freq"] = filt["l_freq"][0] + with pytest.raises(RuntimeError, match="as both tuple and non-tuple"): + _uncombine_filters(filt) + filt["h_freq"] = filt["h_freq"][0] + with pytest.raises(RuntimeError, match="as both tuple and non-tuple"): + _uncombine_filters(filt) + filt["iir_params"] = filt["iir_params"][0] + filt2 = _uncombine_filters(filt) + assert filt == filt2[0] + + @pytest.fixture(scope="function") def filters(iir_params: dict[str, Any], sfreq: float) -> list[StreamFilter]: """Create a list of valid filters.""" @@ -329,4 +348,4 @@ def filter_(request, iir_params: dict[str, Any], sfreq: float) -> StreamFilter: def test_sanitize_filters(filters: list[StreamFilter], filter_: StreamFilter): """Test clean-up of filter list to ensure non-overlap between channels.""" - pass + filters_ = _sanitize_filters(filters, filter_) From 07c01a92750f4c22740d1c5eaa8e95aa143186fb Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:45:02 +0100 Subject: [PATCH 32/69] fix style --- mne_lsl/stream/tests/test_filters.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 62ba79c25..7b67ce09a 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -8,11 +8,7 @@ from mne.filter import create_filter from scipy.signal import sosfilt_zi -from mne_lsl.stream._filters import ( - StreamFilter, - _combine_filters, - _uncombine_filters, -) +from mne_lsl.stream._filters import StreamFilter, _combine_filters, _uncombine_filters if TYPE_CHECKING: from typing import Any From 6700fec2b1177af9a8ed03bbc5580b626b3f0589 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:45:46 +0100 Subject: [PATCH 33/69] fix imports [ci skip] --- mne_lsl/stream/tests/test_filters.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 7b67ce09a..38e7814db 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -8,7 +8,12 @@ from mne.filter import create_filter from scipy.signal import sosfilt_zi -from mne_lsl.stream._filters import StreamFilter, _combine_filters, _uncombine_filters +from mne_lsl.stream._filters import ( + StreamFilter, + _combine_filters, + _sanitize_filters, + _uncombine_filters, +) if TYPE_CHECKING: from typing import Any From eb5464f216fcda156bfa25a3bc114de8c8c09cd7 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 15:49:40 +0100 Subject: [PATCH 34/69] add tests --- mne_lsl/stream/tests/test_filters.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 38e7814db..decefedaf 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -349,4 +349,9 @@ def filter_(request, iir_params: dict[str, Any], sfreq: float) -> StreamFilter: def test_sanitize_filters(filters: list[StreamFilter], filter_: StreamFilter): """Test clean-up of filter list to ensure non-overlap between channels.""" - filters_ = _sanitize_filters(filters, filter_) + # look for overlapping channels + overlap = [np.intersect1d(filt["picks"], filter_["picks"]) for filt in filters] + # sanitize and validate output + if all(ol.size == 0 for ol in overlap): + filts = _sanitize_filters(filters, filter_) + assert filts == filters + [filter_] From 5c659f9932072831d2b3070f446fb0030d6265fb Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 16:02:59 +0100 Subject: [PATCH 35/69] more tests [ci skip] --- mne_lsl/stream/_filters.py | 19 +++++++++++-------- mne_lsl/stream/tests/test_filters.py | 7 ++++++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index accfce3c8..9e4757fad 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -60,14 +60,13 @@ def _combine_filters( filter1: StreamFilter, filter2: StreamFilter, picks: NDArray[+ScalarIntType], - *, - copy: bool = True, ) -> StreamFilter: """Combine 2 filters applied on the same set of channels.""" assert filter1["sfreq"] == filter2["sfreq"] - if copy: - filter1 = deepcopy(filter1) - filter2 = deepcopy(filter2) + # copy is required else we might end-up modifying the items of the filters used in + # the acquisition thread. + filter1 = deepcopy(filter1) + filter2 = deepcopy(filter2) system = np.vstack((filter1["sos"], filter2["sos"])) # for 'l_freq', 'h_freq', 'iir_params' we store the filter(s) settings in ordered # tuples to keep track of the original settings of individual filters. @@ -98,7 +97,7 @@ def _uncombine_filters(filter_: StreamFilter) -> list[StreamFilter]: val = [ isinstance(filter_[key], tuple) for key in ("l_freq", "h_freq", "iir_params") ] - if not all(val) and any(val): + if not all(val) and any(val): # sanity-check raise RuntimeError( "The combined filter contains keys 'l_freq', 'h_freq' and 'iir_params' as " "both tuple and non-tuple, which should not be possible. Please contact " @@ -138,10 +137,14 @@ def _uncombine_filters(filter_: StreamFilter) -> list[StreamFilter]: def _sanitize_filters( - filters: list[StreamFilter], filter_: StreamFilter, *, copy: bool = True + filters: list[StreamFilter], + filter_: StreamFilter, ) -> list[dict[str, Any]]: """Sanitize the list of filters to ensure non-overlapping channels.""" - filters = deepcopy(filters) if copy else filters + # copy is required else we might end-up modifying the 'picks' item of the filter + # list used in the acquisition thread. + filters = deepcopy(filters) + filter_ = deepcopy(filter_) additional_filters = [] for filt in filters: intersection = np.intersect1d( diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index decefedaf..fdfb1421d 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -1,6 +1,7 @@ from __future__ import annotations # c.f. PEP 563, PEP 649 from copy import deepcopy +from itertools import chain from typing import TYPE_CHECKING import numpy as np @@ -352,6 +353,10 @@ def test_sanitize_filters(filters: list[StreamFilter], filter_: StreamFilter): # look for overlapping channels overlap = [np.intersect1d(filt["picks"], filter_["picks"]) for filt in filters] # sanitize and validate output + filts = _sanitize_filters(filters, filter_) if all(ol.size == 0 for ol in overlap): - filts = _sanitize_filters(filters, filter_) assert filts == filters + [filter_] + picks = list(chain(*(filt["picks"] for filt in filts))) + assert np.unique(picks).size == len(picks) # ensure no more overlap + # find pairs of filters that have been combined + idx = [i for i, ol in enumerate(overlap) if ol.size != 0] From cb1a2c5113cd85c16b88350ca762e4f2f01381f3 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 16:08:01 +0100 Subject: [PATCH 36/69] more tests --- mne_lsl/stream/tests/test_filters.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index fdfb1421d..e4b67c35a 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -1,6 +1,7 @@ from __future__ import annotations # c.f. PEP 563, PEP 649 from copy import deepcopy +from distutils.filelist import FileList from itertools import chain from typing import TYPE_CHECKING @@ -356,7 +357,13 @@ def test_sanitize_filters(filters: list[StreamFilter], filter_: StreamFilter): filts = _sanitize_filters(filters, filter_) if all(ol.size == 0 for ol in overlap): assert filts == filters + [filter_] - picks = list(chain(*(filt["picks"] for filt in filts))) - assert np.unique(picks).size == len(picks) # ensure no more overlap - # find pairs of filters that have been combined - idx = [i for i, ol in enumerate(overlap) if ol.size != 0] + else: + picks = list(chain(*(filt["picks"] for filt in filts))) + assert np.unique(picks).size == len(picks) # ensure no more overlap + # find pairs of filters that have been combined + idx = [k for k, ol in enumerate(overlap) if ol.size != 0] + for k in idx: + filt = _combine_filters(filters[k], filter_, overlap[k]) + assert filt in filts + assert filters[k] not in filts + assert filter_ not in filts From b1d6dce23708e6a9aa9389c747faf6bb917cc5b9 Mon Sep 17 00:00:00 2001 From: mscheltienne Date: Fri, 1 Mar 2024 16:08:32 +0100 Subject: [PATCH 37/69] fix import --- mne_lsl/stream/tests/test_filters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index e4b67c35a..7d95d8c3e 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -1,7 +1,6 @@ from __future__ import annotations # c.f. PEP 563, PEP 649 from copy import deepcopy -from distutils.filelist import FileList from itertools import chain from typing import TYPE_CHECKING From 6f6aff0b7a71bb14d64265a1220b694e22a972dc Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Fri, 1 Mar 2024 17:54:49 +0100 Subject: [PATCH 38/69] better comparison Co-authored-by: Eric Larson --- mne_lsl/stream/_filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index 9e4757fad..5552acd2c 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -24,7 +24,7 @@ def __repr__(self): # noqa: D105 def __eq__(self, other: Any): """Equality operator.""" - if not isinstance(other, StreamFilter) or sorted(self) != sorted(other): + if not isinstance(other, StreamFilter) or set(self) != set(other): return False for key in self: if key == "zi": # special case since it's either a np.ndarray or None From b2bbd1ad42796eb825c0ef309dfdeb524151085b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 16:48:30 +0100 Subject: [PATCH 39/69] re-simplify --- mne_lsl/stream/_base.py | 11 +- mne_lsl/stream/_filters.py | 134 ++----------- mne_lsl/stream/stream_lsl.py | 2 +- mne_lsl/stream/tests/test_filters.py | 282 +-------------------------- 4 files changed, 17 insertions(+), 412 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 167569931..c6f42f6b9 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -32,7 +32,7 @@ from ..utils._docs import copy_doc, fill_doc from ..utils.logs import logger, verbose from ..utils.meas_info import _HUMAN_UNITS, _set_channel_units -from ._filters import StreamFilter, _sanitize_filters +from ._filters import StreamFilter if TYPE_CHECKING: from datetime import datetime @@ -440,16 +440,9 @@ def filter( # remove duplicate information del filt["order"] del filt["ftype"] - # to correctly handle the filter initial conditions even if 2 filters are - # applied to the same channels, we need to separate the 'picks' between filter - # to avoid any channel-overlap between filters. - # if the initial conditions are updated in real-time in the _acquire function, - # we need to update the 'zi' for each individual second order filter in the - # 'sos' output, which does not seem to be supported by scipy directly. - filters = _sanitize_filters(self._filters, StreamFilter(filt)) # add filter to the list of applied filters with self._interrupt_acquisition(): - self._filters = filters + self._filters.append(StreamFilter(filt)) @copy_doc(ContainsMixin.get_channel_types) def get_channel_types( diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index 5552acd2c..81d155e69 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -1,26 +1,31 @@ from __future__ import annotations # c.f. PEP 563, PEP 649 -from copy import deepcopy from typing import TYPE_CHECKING from warnings import warn import numpy as np -from mne.filter import create_filter, estimate_ringing_samples -from scipy.signal import sosfilt_zi if TYPE_CHECKING: from typing import Any - from numpy.typing import NDArray - - from .._typing import ScalarIntType - class StreamFilter(dict): """Class defining a filter.""" + _ORDER_STR: dict[int, str] = {1: "1st", 2: "2nd"} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for key in ("ftype", "order"): + if key in self: + assert key in self["iir_params"] # sanity-check + del self[key] + def __repr__(self): # noqa: D105 - return f"" + order = self._ORDER_STR.get( + self["irr_params"]["order"], f"{self['irr_params']['order']}th" + ) + return f"" def __eq__(self, other: Any): """Equality operator.""" @@ -54,116 +59,3 @@ def __eq__(self, other: Any): def __ne__(self, other: Any): # explicit method required to issue warning """Inequality operator.""" return not self.__eq__(other) - - -def _combine_filters( - filter1: StreamFilter, - filter2: StreamFilter, - picks: NDArray[+ScalarIntType], -) -> StreamFilter: - """Combine 2 filters applied on the same set of channels.""" - assert filter1["sfreq"] == filter2["sfreq"] - # copy is required else we might end-up modifying the items of the filters used in - # the acquisition thread. - filter1 = deepcopy(filter1) - filter2 = deepcopy(filter2) - system = np.vstack((filter1["sos"], filter2["sos"])) - # for 'l_freq', 'h_freq', 'iir_params' we store the filter(s) settings in ordered - # tuples to keep track of the original settings of individual filters. - for key in ("l_freq", "h_freq", "iir_params"): - filter1[key] = list( - (filter1[key],) if not isinstance(filter1[key], tuple) else filter1[key] - ) - filter2[key] = list( - (filter2[key],) if not isinstance(filter2[key], tuple) else filter2[key] - ) - combined_filter = { - "output": "sos", - "padlen": estimate_ringing_samples(system), - "sos": system, - "zi": None, # reset initial conditions on channels combined - "zi_coeff": sosfilt_zi(system)[..., np.newaxis], - "l_freq": tuple(filter1["l_freq"] + filter2["l_freq"]), - "h_freq": tuple(filter1["h_freq"] + filter2["h_freq"]), - "iir_params": tuple(filter1["iir_params"] + filter2["iir_params"]), - "sfreq": filter1["sfreq"], - "picks": picks, - } - return StreamFilter(combined_filter) - - -def _uncombine_filters(filter_: StreamFilter) -> list[StreamFilter]: - """Uncombine a combined filter into its individual components.""" - val = [ - isinstance(filter_[key], tuple) for key in ("l_freq", "h_freq", "iir_params") - ] - if not all(val) and any(val): # sanity-check - raise RuntimeError( - "The combined filter contains keys 'l_freq', 'h_freq' and 'iir_params' as " - "both tuple and non-tuple, which should not be possible. Please contact " - "the developers." - ) - elif all(elt is False for elt in val): - return [filter_] - # instead of trying to un-tangled the 'sos' matrix, we simply create a new filter - # for each individual component. - filters = list() - for lfq, hfq, iir_param in zip( - filter_["l_freq"], filter_["h_freq"], filter_["iir_params"] - ): - filt = create_filter( - data=None, - sfreq=filter_["sfreq"], - l_freq=lfq, - h_freq=hfq, - method="iir", - iir_params=iir_param, - phase="forward", - verbose="CRITICAL", # effectively disable logs - ) - filt.update( - zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], - zi=None, - l_freq=lfq, - h_freq=hfq, - iir_params=iir_param, - sfreq=filter_["sfreq"], - picks=filter_["picks"], - ) - del filt["order"] - del filt["ftype"] - filters.append(StreamFilter(filt)) - return filters - - -def _sanitize_filters( - filters: list[StreamFilter], - filter_: StreamFilter, -) -> list[dict[str, Any]]: - """Sanitize the list of filters to ensure non-overlapping channels.""" - # copy is required else we might end-up modifying the 'picks' item of the filter - # list used in the acquisition thread. - filters = deepcopy(filters) - filter_ = deepcopy(filter_) - additional_filters = [] - for filt in filters: - intersection = np.intersect1d( - filt["picks"], filter_["picks"], assume_unique=True - ) - if intersection.size == 0: - continue # non-overlapping channels - additional_filters.append(_combine_filters(filt, filter_, picks=intersection)) - # reset initial conditions for the overlapping filter - filt["zi"] = None # TODO: instead of reset, select initial conditions. - # remove overlapping channels from both filters - filt["picks"] = np.setdiff1d(filt["picks"], intersection, assume_unique=True) - filter_["picks"] = np.setdiff1d( - filter_["picks"], intersection, assume_unique=True - ) - # prune filters without any channels - filters = [ - filt - for filt in filters + additional_filters + [filter_] - if filt["picks"].size != 0 - ] - return filters diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index 79d83e57c..0daea0c67 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -266,7 +266,7 @@ def _acquire(self) -> None: data_filtered, filt["zi"] = sosfilt( filt["sos"], data[:, filt["picks"]], zi=filt["zi"], axis=0 ) - data[:, filt["picks"]] = data_filtered + data[:, filt["picks"]] = data_filtered # operate in-place # roll and update buffers self._buffer = np.roll(self._buffer, -timestamps.size, axis=0) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 7d95d8c3e..a1b38daa8 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -1,7 +1,6 @@ from __future__ import annotations # c.f. PEP 563, PEP 649 from copy import deepcopy -from itertools import chain from typing import TYPE_CHECKING import numpy as np @@ -9,18 +8,11 @@ from mne.filter import create_filter from scipy.signal import sosfilt_zi -from mne_lsl.stream._filters import ( - StreamFilter, - _combine_filters, - _sanitize_filters, - _uncombine_filters, -) +from mne_lsl.stream._filters import StreamFilter if TYPE_CHECKING: from typing import Any - from numpy.typing import NDArray - @pytest.fixture(scope="module") def iir_params() -> dict[str, Any]: @@ -34,211 +26,6 @@ def sfreq() -> float: return 1000.0 -@pytest.fixture(scope="module") -def picks() -> NDArray[np.int8]: - """Return a valid selection of channels.""" - return np.arange(0, 10, dtype=np.int8) - - -@pytest.fixture(scope="function") -def filter1( - iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int8] -) -> StreamFilter: - """Create a filter.""" - l_freq = 1.0 - h_freq = 40.0 - filt = create_filter( - data=None, - sfreq=sfreq, - l_freq=l_freq, - h_freq=h_freq, - method="iir", - iir_params=iir_params, - phase="forward", - verbose="CRITICAL", - ) - filt.update( - zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], - zi=None, - l_freq=l_freq, - h_freq=h_freq, - iir_params=iir_params, - sfreq=sfreq, - picks=picks, - ) - del filt["order"] - del filt["ftype"] - return StreamFilter(filt) - - -@pytest.fixture(scope="function") -def filter2( - iir_params: dict[str, Any], sfreq: float, picks: NDArray[np.int8] -) -> StreamFilter: - """Create a filter.""" - l_freq = 2.0 - h_freq = None - filt = create_filter( - data=None, - sfreq=sfreq, - l_freq=l_freq, - h_freq=h_freq, - method="iir", - iir_params=iir_params, - phase="forward", - verbose="CRITICAL", - ) - filt.update( - zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], - zi=None, - l_freq=l_freq, - h_freq=h_freq, - iir_params=iir_params, - sfreq=sfreq, - picks=picks, - ) - del filt["order"] - del filt["ftype"] - return StreamFilter(filt) - - -@pytest.fixture(scope="function") -def filter3(sfreq: float, picks: NDArray[np.int8]) -> StreamFilter: - """Create a filter.""" - l_freq = None - h_freq = 80.0 - iir_params = dict(order=2, ftype="bessel", output="sos") - filt = create_filter( - data=None, - sfreq=sfreq, - l_freq=l_freq, - h_freq=h_freq, - method="iir", - iir_params=iir_params, - phase="forward", - verbose="CRITICAL", - ) - filt.update( - zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], - zi=None, - l_freq=l_freq, - h_freq=h_freq, - iir_params=iir_params, - sfreq=sfreq, - picks=picks, - ) - del filt["order"] - del filt["ftype"] - return StreamFilter(filt) - - -def test_combine_uncombine_filters( - filter1: StreamFilter, - filter2: StreamFilter, - filter3: StreamFilter, - picks: NDArray[np.int8], -): - """Test (un)combination of filters.""" - # uncombine self - filt = _uncombine_filters(filter1) - assert filter1 == filt[0] - - # combine 2 filters - filt = _combine_filters(filter1, filter2, picks) - assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter2["sos"]))) - assert filt["sos"].shape[-1] == 6 - assert filt["l_freq"] == (filter1["l_freq"], filter2["l_freq"]) - assert filt["h_freq"] == (filter1["h_freq"], filter2["h_freq"]) - assert not np.array_equal(filt["zi_coeff"], filter1["zi_coeff"]) - assert not np.array_equal(filt["zi_coeff"], filter2["zi_coeff"]) - assert filt["zi"] is None - assert np.array_equal(filt["picks"], filter1["picks"]) - assert np.array_equal(filt["picks"], filter2["picks"]) - assert filt["sfreq"] == filter1["sfreq"] == filter2["sfreq"] - assert filt["iir_params"] == (filter1["iir_params"], filter2["iir_params"]) - filt1, filt2 = _uncombine_filters(filt) - assert filt1 == filter1 - assert filt2 == filter2 - - # add initial conditions - filter2["zi"] = filter2["zi_coeff"] * 5 - assert filt2 != filter2 - filt = _combine_filters(filter1, filter2, picks) - assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter2["sos"]))) - assert filt["sos"].shape[-1] == 6 - assert filt["l_freq"] == (filter1["l_freq"], filter2["l_freq"]) - assert filt["h_freq"] == (filter1["h_freq"], filter2["h_freq"]) - assert not np.array_equal(filt["zi_coeff"], filter1["zi_coeff"]) - assert not np.array_equal(filt["zi_coeff"], filter2["zi_coeff"]) - assert filt["zi"] is None - assert np.array_equal(filt["picks"], filter1["picks"]) - assert np.array_equal(filt["picks"], filter2["picks"]) - assert filt["sfreq"] == filter1["sfreq"] == filter2["sfreq"] - assert filt["iir_params"] == (filter1["iir_params"], filter2["iir_params"]) - filt1, filt2 = _uncombine_filters(filt) - assert filt1 == filter1 - assert filt2 != filter2 - filter2["zi"] = None - assert filt2 == filter2 - - # test with different filter type - filt = _combine_filters(filter1, filter3, picks) - assert np.array_equal(filt["sos"], np.vstack((filter1["sos"], filter3["sos"]))) - assert filt["sos"].shape[-1] == 6 - assert filt["l_freq"] == (filter1["l_freq"], filter3["l_freq"]) - assert filt["h_freq"] == (filter1["h_freq"], filter3["h_freq"]) - assert not np.array_equal(filt["zi_coeff"], filter1["zi_coeff"]) - assert not np.array_equal(filt["zi_coeff"], filter3["zi_coeff"]) - assert filt["zi"] is None - assert np.array_equal(filt["picks"], filter1["picks"]) - assert np.array_equal(filt["picks"], filter3["picks"]) - assert filt["sfreq"] == filter1["sfreq"] == filter3["sfreq"] - assert filt["iir_params"] == (filter1["iir_params"], filter3["iir_params"]) - filt1, filt3 = _uncombine_filters(filt) - assert filt1 == filter1 - assert filt3 == filter3 - - # test combination of 3 filters - filt_ = _combine_filters(filt, filter2, picks) - assert np.array_equal( - filt_["sos"], np.vstack((filter1["sos"], filter3["sos"], filter2["sos"])) - ) - assert np.array_equal(filt_["sos"], np.vstack((filt["sos"], filter2["sos"]))) - assert filt_["sos"].shape[-1] == 6 - assert filt_["l_freq"] == (filter1["l_freq"], filter3["l_freq"], filter2["l_freq"]) - assert filt_["h_freq"] == (filter1["h_freq"], filter3["h_freq"], filter2["h_freq"]) - assert not np.array_equal(filt_["zi_coeff"], filt["zi_coeff"]) - assert not np.array_equal(filt_["zi_coeff"], filter1["zi_coeff"]) - assert not np.array_equal(filt_["zi_coeff"], filter2["zi_coeff"]) - assert not np.array_equal(filt_["zi_coeff"], filter3["zi_coeff"]) - assert filt_["zi"] is None - assert np.array_equal(filt_["picks"], picks) - assert filt_["sfreq"] == filter1["sfreq"] == filter2["sfreq"] == filter3["sfreq"] - assert filt_["iir_params"] == ( - filter1["iir_params"], - filter3["iir_params"], - filter2["iir_params"], - ) - filt1, filt3, filt2 = _uncombine_filters(filt_) - assert filt1 == filter1 - assert filt2 == filter2 # zi already set to None - assert filt3 == filter3 - - -def test_invalid_uncombine_filters(filter1, filter2, picks): - """Test error raising in uncombine filters.""" - filt = _combine_filters(filter1, filter2, picks) - filt["l_freq"] = filt["l_freq"][0] - with pytest.raises(RuntimeError, match="as both tuple and non-tuple"): - _uncombine_filters(filt) - filt["h_freq"] = filt["h_freq"][0] - with pytest.raises(RuntimeError, match="as both tuple and non-tuple"): - _uncombine_filters(filt) - filt["iir_params"] = filt["iir_params"][0] - filt2 = _uncombine_filters(filt) - assert filt == filt2[0] - - @pytest.fixture(scope="function") def filters(iir_params: dict[str, Any], sfreq: float) -> list[StreamFilter]: """Create a list of valid filters.""" @@ -299,70 +86,3 @@ def test_StreamFilter(filters: StreamFilter): assert filter2 != filters[0] # test representation assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) - - -@pytest.fixture( - scope="function", - params=[ - (None, 100, np.arange(30, 40)), - (10, 100, np.arange(30, 40)), - (50, None, np.arange(30, 40)), - (None, 100, np.arange(0, 10)), - (10, 100, np.arange(0, 10)), - (50, None, np.arange(0, 10)), - (None, 100, np.arange(5, 15)), - (10, 100, np.arange(5, 15)), - (50, None, np.arange(5, 15)), - (None, 100, np.arange(5, 10)), - (10, 100, np.arange(5, 10)), - (50, None, np.arange(5, 10)), - (None, 100, np.arange(5, 25)), - (10, 100, np.arange(5, 25)), - (50, None, np.arange(5, 25)), - ], -) -def filter_(request, iir_params: dict[str, Any], sfreq: float) -> StreamFilter: - """Create a filter.""" - l_freq, h_freq, picks = request.param - filt = create_filter( - data=None, - sfreq=sfreq, - l_freq=l_freq, - h_freq=h_freq, - method="iir", - iir_params=iir_params, - phase="forward", - verbose="CRITICAL", - ) - filt.update( - zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], - zi=None, - l_freq=l_freq, - h_freq=h_freq, - iir_params=iir_params, - sfreq=sfreq, - picks=picks, - ) - del filt["order"] - del filt["ftype"] - return StreamFilter(filt) - - -def test_sanitize_filters(filters: list[StreamFilter], filter_: StreamFilter): - """Test clean-up of filter list to ensure non-overlap between channels.""" - # look for overlapping channels - overlap = [np.intersect1d(filt["picks"], filter_["picks"]) for filt in filters] - # sanitize and validate output - filts = _sanitize_filters(filters, filter_) - if all(ol.size == 0 for ol in overlap): - assert filts == filters + [filter_] - else: - picks = list(chain(*(filt["picks"] for filt in filts))) - assert np.unique(picks).size == len(picks) # ensure no more overlap - # find pairs of filters that have been combined - idx = [k for k, ol in enumerate(overlap) if ol.size != 0] - for k in idx: - filt = _combine_filters(filters[k], filter_, overlap[k]) - assert filt in filts - assert filters[k] not in filts - assert filter_ not in filts From 64dcd14d8c44b6b090feb2808b0135006d3aabda Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 17:00:20 +0100 Subject: [PATCH 40/69] simplify stream code --- mne_lsl/stream/_base.py | 19 +---------- mne_lsl/stream/_filters.py | 49 +++++++++++++++++++++++++++- mne_lsl/stream/stream_lsl.py | 2 +- mne_lsl/stream/tests/test_filters.py | 6 ++-- 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index c6f42f6b9..9d3cd2abc 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -12,7 +12,6 @@ from mne.channels import rename_channels from mne.filter import create_filter from mne.utils import check_version -from scipy.signal import sosfilt_zi if check_version("mne", "1.6"): from mne._fiff.constants import FIFF, _ch_unit_mul_named @@ -418,28 +417,12 @@ def filter( iir_params["output"] = "sos" # construct an IIR filter filt = create_filter( - data=None, sfreq=self._info["sfreq"], l_freq=l_freq, h_freq=h_freq, - method="iir", iir_params=iir_params, - phase="forward", - verbose=logger.level if verbose is None else verbose, - ) - # store filter parameters and initial conditions - filt.update( - zi_coeff=sosfilt_zi(filt["sos"])[..., np.newaxis], - zi=None, - l_freq=l_freq, - h_freq=h_freq, - iir_params=iir_params, - sfreq=self._info["sfreq"], - picks=picks, ) - # remove duplicate information - del filt["order"] - del filt["ftype"] + filt.update(picks=picks) # channel selection # add filter to the list of applied filters with self._interrupt_acquisition(): self._filters.append(StreamFilter(filt)) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index 81d155e69..c43552f59 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -4,9 +4,13 @@ from warnings import warn import numpy as np +from mne.filter import create_filter as create_filter_mne +from scipy.signal import sosfilt_zi + +from ..utils._logs import logger if TYPE_CHECKING: - from typing import Any + from typing import Any, Optional class StreamFilter(dict): @@ -59,3 +63,46 @@ def __eq__(self, other: Any): def __ne__(self, other: Any): # explicit method required to issue warning """Inequality operator.""" return not self.__eq__(other) + + +def create_filter( + sfreq: float, + l_freq: Optional[float], + h_freq: Optional[float], + iir_params: dict[str, Any], +) -> dict[str, Any]: + """Create an IIR causal filter. + + Parameters + ---------- + sfreq : float + The sampling ferquency in Hz. + %(l_freq)s + %(h_freq)s + %(iir_params)s + + Returns + ------- + filt : dict + The filter parameters and initial conditions. + """ + filt = create_filter_mne( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=h_freq, + method="iir", + iir_params=iir_params, + phase="forward", + verbose=logger.level, + ) + # store filter parameters and initial conditions + filt.update( + zi_unit=sosfilt_zi(filt["sos"])[..., np.newaxis], + zi=None, + l_freq=l_freq, + h_freq=h_freq, + iir_params=iir_params, + sfreq=sfreq, + ) + return filt diff --git a/mne_lsl/stream/stream_lsl.py b/mne_lsl/stream/stream_lsl.py index 0daea0c67..f132a47a0 100644 --- a/mne_lsl/stream/stream_lsl.py +++ b/mne_lsl/stream/stream_lsl.py @@ -260,7 +260,7 @@ def _acquire(self) -> None: if filt["zi"] is None: # initial conditions are set to a step response steady-state set # on the mean on the acquisition window (e.g. DC offset for EEGs) - filt["zi"] = filt["zi_coeff"] * np.mean( + filt["zi"] = filt["zi_unit"] * np.mean( data[:, filt["picks"]], axis=0 ) data_filtered, filt["zi"] = sosfilt( diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index a1b38daa8..1b24551bf 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -46,10 +46,10 @@ def filters(iir_params: dict[str, Any], sfreq: float) -> list[StreamFilter]: for lfq, hfq in zip(l_freqs, h_freqs) ] for k, (filt, lfq, hfq, picks_) in enumerate(zip(filters, l_freqs, h_freqs, picks)): - zi_coeff = sosfilt_zi(filt["sos"])[..., np.newaxis] + zi_unit = sosfilt_zi(filt["sos"])[..., np.newaxis] filt.update( - zi_coeff=zi_coeff, - zi=zi_coeff * k, + zi_unit=zi_unit, + zi=zi_unit * k, l_freq=lfq, h_freq=hfq, iir_params=iir_params, From 1ccf2e1055fe1faecb42a26c356457e02c116b7b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 17:03:01 +0100 Subject: [PATCH 41/69] add entries to docdict to de-duplicate docstrings --- mne_lsl/stream/_base.py | 15 +++------------ mne_lsl/utils/_docs.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 9d3cd2abc..fbd61b1d5 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -366,19 +366,10 @@ def filter( Parameters ---------- - l_freq : float | None - The lower cutoff frequency. If None, the buffer is only low-passed. - h_freq : float | None - The higher cutoff frequency. If None, the buffer is only high-passed. + %(l_freq)s + %(h_freq)s %(picks_all)s - iir_params : dict | None - Dictionary of parameters to use for IIR filtering. If None, a 4th order - Butterworth will be used. For more information, see - :func:`mne.filter.construct_iir_filter`. - - .. note:: - - The output ``sos`` must be used. The ``ba`` output is not supported. + %(iir_params)s %(verbose)s Returns diff --git a/mne_lsl/utils/_docs.py b/mne_lsl/utils/_docs.py index 2663502bb..40a035ea9 100644 --- a/mne_lsl/utils/_docs.py +++ b/mne_lsl/utils/_docs.py @@ -44,10 +44,27 @@ # -- F --------------------------------------------------------------------------------- # -- G --------------------------------------------------------------------------------- # -- H --------------------------------------------------------------------------------- +docdict["h_freq"] = """h_freq : float | None + The higher cutoff frequency. If None, the buffer is only high-passed.""" + # -- I --------------------------------------------------------------------------------- +docdict["iir_params"] = """ +iir_params : dict | None + Dictionary of parameters to use for IIR filtering. If None, a 4th order + Butterworth will be used. For more information, see + :func:`mne.filter.construct_iir_filter`. + + .. note:: + + The output ``sos`` must be used. The ``ba`` output is not supported.""" + # -- J --------------------------------------------------------------------------------- # -- K --------------------------------------------------------------------------------- # -- L --------------------------------------------------------------------------------- +docdict["l_freq"] = """ +l_freq : float | None + The lower cutoff frequency. If None, the buffer is only low-passed.""" + # -- M --------------------------------------------------------------------------------- # -- N --------------------------------------------------------------------------------- # -- O --------------------------------------------------------------------------------- From 6c8f59c83f0ce049b92f0404b0034b161908d871 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 17:20:14 +0100 Subject: [PATCH 42/69] fix tests for simplification --- mne_lsl/stream/_filters.py | 7 +++- mne_lsl/stream/tests/test_filters.py | 63 +++++++++++++++------------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index c43552f59..a02067b1a 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -7,7 +7,7 @@ from mne.filter import create_filter as create_filter_mne from scipy.signal import sosfilt_zi -from ..utils._logs import logger +from ..utils.logs import logger if TYPE_CHECKING: from typing import Any, Optional @@ -29,7 +29,10 @@ def __repr__(self): # noqa: D105 order = self._ORDER_STR.get( self["irr_params"]["order"], f"{self['irr_params']['order']}th" ) - return f"" + return ( + f"" + ) def __eq__(self, other: Any): """Equality operator.""" diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 1b24551bf..6dcbd2b7a 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -5,10 +5,8 @@ import numpy as np import pytest -from mne.filter import create_filter -from scipy.signal import sosfilt_zi -from mne_lsl.stream._filters import StreamFilter +from mne_lsl.stream._filters import StreamFilter, create_filter if TYPE_CHECKING: from typing import Any @@ -26,44 +24,49 @@ def sfreq() -> float: return 1000.0 -@pytest.fixture(scope="function") +@pytest.fixture(scope="module") def filters(iir_params: dict[str, Any], sfreq: float) -> list[StreamFilter]: """Create a list of valid filters.""" l_freqs = (1, 1, 0.1) h_freqs = (40, 15, None) picks = (np.arange(0, 10), np.arange(10, 20), np.arange(20, 30)) - filters = [ - create_filter( - data=None, + filters = list() + for k, (lfq, hfq, picks_) in enumerate(zip(l_freqs, h_freqs, picks)): + filt = create_filter( sfreq=sfreq, l_freq=lfq, h_freq=hfq, - method="iir", - iir_params=iir_params, - phase="forward", - verbose="CRITICAL", # disable logs - ) - for lfq, hfq in zip(l_freqs, h_freqs) - ] - for k, (filt, lfq, hfq, picks_) in enumerate(zip(filters, l_freqs, h_freqs, picks)): - zi_unit = sosfilt_zi(filt["sos"])[..., np.newaxis] - filt.update( - zi_unit=zi_unit, - zi=zi_unit * k, - l_freq=lfq, - h_freq=hfq, iir_params=iir_params, - sfreq=sfreq, - picks=picks_, ) + filt.update(picks=picks_) + filt["zi"] = k * filt["zi_unit"] del filt["order"] del filt["ftype"] - all_picks = np.hstack([filt["picks"] for filt in filters]) - assert np.unique(all_picks).size == all_picks.size # sanity-check - return [StreamFilter(filt) for filt in filters] + filters.append(StreamFilter(filt)) + return filters + + +def test_StreamFilter(iir_params: dict[str, Any], sfreq: float): + """Test StreamFilter creation.""" + # test deletion of duplicates + filt = create_filter( + sfreq=sfreq, + l_freq=1, + h_freq=101, + iir_params=iir_params, + ) + filt.update(picks=np.arange(5, 15)) + filt = StreamFilter(filt) + assert "order" not in filt + assert "order" in filt["iir_params"] + assert "ftype" not in filt + assert "ftype" in filt["iir_params"] + # test creation from self + filt2 = StreamFilter(filt) + assert filt == filt2 -def test_StreamFilter(filters: StreamFilter): +def test_StreamFilter_comparison(filters: StreamFilter): """Test the StreamFilter class.""" filter2 = deepcopy(filters[0]) assert filter2 == filters[0] @@ -84,5 +87,9 @@ def test_StreamFilter(filters: StreamFilter): filter2 = deepcopy(filters[0]) del filter2["sos"] assert filter2 != filters[0] - # test representation + + +def test_StreamFilter_repr(filters): + """Test the representation.""" assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) + assert filters[0]["iir_params"]["order"] in repr(filters[0]) From e343d8a0a617d266e72ea81f6f8fbeedb8611d92 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 18:09:54 +0100 Subject: [PATCH 43/69] fix typos --- mne_lsl/stream/_filters.py | 2 +- mne_lsl/stream/tests/test_filters.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index a02067b1a..44170cdd0 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs): def __repr__(self): # noqa: D105 order = self._ORDER_STR.get( - self["irr_params"]["order"], f"{self['irr_params']['order']}th" + self["iir_params"]["order"], f"{self['iir_params']['order']}th" ) return ( f" Date: Mon, 4 Mar 2024 18:15:44 +0100 Subject: [PATCH 44/69] add test for create_filter --- mne_lsl/stream/tests/test_filters.py | 35 +++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index 1aeb0d5a2..ceb78b29e 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -5,11 +5,13 @@ import numpy as np import pytest +from mne.filter import create_filter as create_filter_mne +from numpy.testing import assert_allclose from mne_lsl.stream._filters import StreamFilter, create_filter if TYPE_CHECKING: - from typing import Any + from typing import Any, Optional @pytest.fixture(scope="module") @@ -66,7 +68,7 @@ def test_StreamFilter(iir_params: dict[str, Any], sfreq: float): assert filt == filt2 -def test_StreamFilter_comparison(filters: StreamFilter): +def test_StreamFilter_comparison(filters: list[StreamFilter]): """Test the StreamFilter class.""" filter2 = deepcopy(filters[0]) assert filter2 == filters[0] @@ -89,7 +91,34 @@ def test_StreamFilter_comparison(filters: StreamFilter): assert filter2 != filters[0] -def test_StreamFilter_repr(filters): +def test_StreamFilter_repr(filters: list[StreamFilter]): """Test the representation.""" assert f"({filters[0]['l_freq']}, {filters[0]['h_freq']})" in repr(filters[0]) assert str(filters[0]["iir_params"]["order"]) in repr(filters[0]) + + +@pytest.mark.parametrize("l_freq, h_freq", [(1, 40), (None, 15), (0.1, None)]) +def test_create_filter( + iir_params: dict[str, Any], + sfreq: float, + l_freq: Optional[float], + h_freq: Optional[float], +): + """Test create_filter conformity with MNE.""" + filter1 = create_filter( + sfreq=sfreq, + l_freq=l_freq, + h_freq=h_freq, + iir_params=iir_params, + ) + filter2 = create_filter_mne( + data=None, + sfreq=sfreq, + l_freq=l_freq, + h_freq=h_freq, + method="iir", + iir_params=iir_params, + phase="forward", + verbose="CRITICAL", + ) + assert_allclose(filter1["sos"], filter2["sos"]) From 2f4c70127e0f8c0c98c6cef90457b23fada5f648 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 18:33:25 +0100 Subject: [PATCH 45/69] add filters property and method to delete filters --- mne_lsl/stream/_base.py | 61 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index fbd61b1d5..c210dfc7c 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -27,7 +27,7 @@ from mne.io.pick import _picks_to_idx from mne.channels.channels import SetChannelsMixin -from ..utils._checks import check_type, check_value +from ..utils._checks import check_type, check_value, ensure_int from ..utils._docs import copy_doc, fill_doc from ..utils.logs import logger, verbose from ..utils.meas_info import _HUMAN_UNITS, _set_channel_units @@ -309,6 +309,59 @@ def disconnect(self) -> BaseStream: # This method needs to close any inlet/network object and need to end with # self._reset_variables(). + def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> None: + """Remove a filter from the list of applied filters. + + Parameters + ---------- + idx : ``'all'``| int | list of int | tuple of int + If the string ``'all'`` (default), remove all filters. If an integer or a + list of integers, remove the filter(s) at the given index(es) from + :py:attr:`Stream.filters`. + + Notes + ----- + When removing a filter, the initial conditions of all the filters applied on + overlapping channels are reset. The initial conditions will be re-estimated as + a step response steady-state. + """ + # validate input + check_type(idx, ("int-like", tuple, list, str), "idx") + if isinstance(idx, str) and idx != "all": + raise ValueError( + "If 'idx' is provided as str, it must be 'all', which will remove all " + "applied filters. Provided '{idx}' is invalid." + ) + elif idx == "all": + idx = np.arange(len(self._filters), dtype=np.uint8) + elif isinstance(idx, (tuple, list)): + for elt in idx: + check_type(elt, ("int-like",), "idx") + idx = np.array(idx, dtype=np.uint8) + else: + idx = np.array([ensure_int(idx, "idx")], dtype=np.uint8) + if not all(0 <= k < len(self._filters) for k in idx): + raise ValueError( + "The index 'idx' must be a positive integer or a list of positive " + "integers not exceeding the number of filters minus 1: " + f"{len(self._filters) - 1}." + ) + # figure out which filter have overlapping channels and will need their initial + # conditions to be reset to a step response steady-state. + picks = np.unique(np.hstack([self._filters[k]["picks"] for k in idx])) + filters2reset = list() + for k, filt in enumerate(self._filters): + if k in idx: + continue # this filter will be deleted + if np.setdiff1d(picks, filt["picks"]).size != 0: + filters2reset.append(k) + # interrupt acquisition and apply changes + with self._interrupt_acquisition(): + for k in filters2reset: + self._filters[k]["zi"] = None + for k in idx[::-1]: + del self._filters[k] + def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStream: """Drop channel(s). @@ -958,7 +1011,6 @@ def compensation_grade(self) -> Optional[int]: self._check_connected(name="compensation_grade") return super().compensation_grade - # ---------------------------------------------------------------------------------- @property def ch_names(self) -> list[str]: """Name of the channels. @@ -994,6 +1046,11 @@ def dtype(self) -> Optional[DTypeLike]: """Channel format of the stream.""" return getattr(self._buffer, "dtype", None) + @property + def filters(self) -> list[StreamFilter]: + """List of filters applied to the real-time Stream.""" + return self._filters + @property def info(self) -> Info: """Info of the LSL stream. From c102df73e38a7b7f6977cb7d76b2f04a1aaf4132 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 18:41:30 +0100 Subject: [PATCH 46/69] add test placeholder and sort idx for deletion --- mne_lsl/stream/_base.py | 9 +++++++++ mne_lsl/stream/tests/test_stream_lsl.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index c210dfc7c..22c6ba97c 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -346,6 +346,15 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non "integers not exceeding the number of filters minus 1: " f"{len(self._filters) - 1}." ) + idx_unique = np.unique(idx) + if idx_unique.size != idx.size: + warn( + "The index 'idx' contains duplicates. Only unique indices will be " + "used.", + RuntimeWarning, + stacklevel=2, + ) + idx = np.sort(idx) # figure out which filter have overlapping channels and will need their initial # conditions to be reset to a step response steady-state. picks = np.unique(np.hstack([self._filters[k]["picks"] for k in idx])) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index ec254c07b..2e1a7a926 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -635,3 +635,8 @@ def test_stream_annotations_picks(_mock_lsl_stream_annotations): data, ts = stream.get_data() assert np.count_nonzero(data) == data.size stream.disconnect() + + +def test_stream_filter_deleetion(): + """Test deletion of filters applied to a Stream.""" + pass From 3bf8b4789e17def99b248b772c764ecab119029f Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 19:14:14 +0100 Subject: [PATCH 47/69] add logs and test for deletion [skip ci] --- mne_lsl/stream/_base.py | 20 ++++++++-- mne_lsl/stream/tests/test_stream_lsl.py | 53 ++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 22c6ba97c..c10c00168 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -10,7 +10,6 @@ import numpy as np from mne import pick_info, pick_types from mne.channels import rename_channels -from mne.filter import create_filter from mne.utils import check_version if check_version("mne", "1.6"): @@ -31,7 +30,7 @@ from ..utils._docs import copy_doc, fill_doc from ..utils.logs import logger, verbose from ..utils.meas_info import _HUMAN_UNITS, _set_channel_units -from ._filters import StreamFilter +from ._filters import StreamFilter, create_filter if TYPE_CHECKING: from datetime import datetime @@ -325,6 +324,9 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non overlapping channels are reset. The initial conditions will be re-estimated as a step response steady-state. """ + self._check_connected_and_regular_sampling("del_filter()") + if len(self._filters) == 0: + raise RuntimeError("No filter to remove.") # validate input check_type(idx, ("int-like", tuple, list, str), "idx") if isinstance(idx, str) and idx != "all": @@ -339,6 +341,8 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non check_type(elt, ("int-like",), "idx") idx = np.array(idx, dtype=np.uint8) else: + # ensure_int is run as a sanity-check, it should not be possible to enter + # this statement without idx as int-like. idx = np.array([ensure_int(idx, "idx")], dtype=np.uint8) if not all(0 <= k < len(self._filters) for k in idx): raise ValueError( @@ -354,7 +358,12 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non RuntimeWarning, stacklevel=2, ) - idx = np.sort(idx) + idx = np.sort(idx_unique) + logger.info( + "Removing filters at index(es): %s\n%s", + ", ".join([str(k) for k in idx]), + "\n".join([repr(self._filters[k]) for k in idx]), + ) # figure out which filter have overlapping channels and will need their initial # conditions to be reset to a step response steady-state. picks = np.unique(np.hstack([self._filters[k]["picks"] for k in idx])) @@ -364,6 +373,11 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non continue # this filter will be deleted if np.setdiff1d(picks, filt["picks"]).size != 0: filters2reset.append(k) + if len(filters2reset) != 0: + logger.info( + "The initial conditions will be reset on filters:\n%s", + "\n".join([repr(self._filters[k]) for k in filters2reset]), + ) # interrupt acquisition and apply changes with self._interrupt_acquisition(): for k in filters2reset: diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 2e1a7a926..d1a2f88f6 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -637,6 +637,55 @@ def test_stream_annotations_picks(_mock_lsl_stream_annotations): stream.disconnect() -def test_stream_filter_deleetion(): +def test_stream_filter_deletion(mock_lsl_stream): """Test deletion of filters applied to a Stream.""" - pass + # test no filter + stream = Stream(bufsize=2.0).connect() + time.sleep(0.1) + with pytest.raises(RuntimeError, match="No filter to remove."): + stream.del_filter("all") + with pytest.raises(RuntimeError, match="No filter to remove."): + stream.del_filter(0) + # test valid deletion + stream.filter(1, 100, picks=["F7", "F3", "Fz"]) + time.sleep(0.1) + assert len(stream.filters) == 1 + stream.del_filter("all") + assert len(stream.filters) == 0 + stream.filter(1, 100, picks=["F7", "F3", "Fz"]) + time.sleep(0.1) + # test invalid + with pytest.raises(ValueError, match="is provided as str, it must be"): + stream.del_filter("0") + with pytest.raises(ValueError, match="is provided as str, it must be"): + stream.del_filter("0") + with pytest.raises(TypeError, match="must be an instance of int-like"): + stream.del_filter(["0"]) + with pytest.raises(TypeError, match="must be an instance of int-like"): + stream.del_filter(("0",)) + with pytest.raises(TypeError, match="must be an instance of"): + stream.del_filter((lambda x: 0,)) + with pytest.raises(TypeError, match="must be an instance of"): + stream.del_filter(lambda x: 0) + with pytest.raises(ValueError, match="must be a positive integer"): + stream.del_filter(1) + with pytest.warns(RuntimeWarning, match="contains duplicates"): + stream.del_filter((0, 0)) + assert len(stream.filters) == 0 + # test reset of initial conditions + stream.disconnect() + time.sleep(0.1) + stream.connect(acquisition_delay=1.0) + time.sleep(0.1) + stream.filter(1, 100, picks=["F7", "F3", "Fz"]) + stream.filter(20, None, picks=["F7", "F3", "O1"]) + stream.filter(None, 20, picks=["Fz", "O2"]) + assert len(stream.filters) == 3 + assert stream.filters[0]["l_freq"] == 1.0 + assert stream.filters[1]["l_freq"] == 20.0 + assert stream.filters[2]["l_freq"] is None + time.sleep(1.1) + assert all(filt["zi"] is not None for filt in stream.filters) + stream.del_filter(2) + assert stream.filters[0]["zi"] is None + assert stream.filters[1]["zi"] is not None From 32347a94a93513f0987e5aa44a12387901268ed0 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 19:33:54 +0100 Subject: [PATCH 48/69] fix deletion test through logs --- mne_lsl/conftest.py | 3 ++- mne_lsl/stream/_base.py | 2 +- mne_lsl/stream/tests/test_stream_lsl.py | 17 +++++++++-------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mne_lsl/conftest.py b/mne_lsl/conftest.py index 4479bffeb..5eed97fb7 100644 --- a/mne_lsl/conftest.py +++ b/mne_lsl/conftest.py @@ -11,7 +11,7 @@ from mne.io import Raw, read_raw_fif from pytest import fixture -from mne_lsl import set_log_level +from mne_lsl import logger, set_log_level from mne_lsl.datasets import testing from mne_lsl.lsl import StreamInlet, StreamOutlet @@ -59,6 +59,7 @@ def pytest_configure(config: Config) -> None: config.addinivalue_line("filterwarnings", warning_line) set_log_level_mne("WARNING") # MNE logger set_log_level("DEBUG") # MNE-lsl logger + logger.propagate = True def pytest_sessionfinish(session, exitstatus) -> None: diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index c10c00168..3fa9aa1e6 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -371,7 +371,7 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non for k, filt in enumerate(self._filters): if k in idx: continue # this filter will be deleted - if np.setdiff1d(picks, filt["picks"]).size != 0: + if np.intersect1d(filt["picks"], picks).size != 0: filters2reset.append(k) if len(filters2reset) != 0: logger.info( diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index d1a2f88f6..716588659 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -1,3 +1,4 @@ +import logging import os import platform import re @@ -637,7 +638,7 @@ def test_stream_annotations_picks(_mock_lsl_stream_annotations): stream.disconnect() -def test_stream_filter_deletion(mock_lsl_stream): +def test_stream_filter_deletion(mock_lsl_stream, caplog): """Test deletion of filters applied to a Stream.""" # test no filter stream = Stream(bufsize=2.0).connect() @@ -673,10 +674,6 @@ def test_stream_filter_deletion(mock_lsl_stream): stream.del_filter((0, 0)) assert len(stream.filters) == 0 # test reset of initial conditions - stream.disconnect() - time.sleep(0.1) - stream.connect(acquisition_delay=1.0) - time.sleep(0.1) stream.filter(1, 100, picks=["F7", "F3", "Fz"]) stream.filter(20, None, picks=["F7", "F3", "O1"]) stream.filter(None, 20, picks=["Fz", "O2"]) @@ -684,8 +681,12 @@ def test_stream_filter_deletion(mock_lsl_stream): assert stream.filters[0]["l_freq"] == 1.0 assert stream.filters[1]["l_freq"] == 20.0 assert stream.filters[2]["l_freq"] is None - time.sleep(1.1) + time.sleep(0.5) assert all(filt["zi"] is not None for filt in stream.filters) + caplog.set_level(logging.INFO) + caplog.clear() stream.del_filter(2) - assert stream.filters[0]["zi"] is None - assert stream.filters[1]["zi"] is not None + assert ( + "The initial conditions will be reset on filters:\n" f"{stream.filters[0]}" + ) in caplog.text + stream.disconnect() From 30f8d7336328340088e27a355e51cff6ac6df272 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 4 Mar 2024 19:34:54 +0100 Subject: [PATCH 49/69] better --- mne_lsl/stream/tests/test_stream_lsl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 716588659..12e627e6d 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -687,6 +687,7 @@ def test_stream_filter_deletion(mock_lsl_stream, caplog): caplog.clear() stream.del_filter(2) assert ( - "The initial conditions will be reset on filters:\n" f"{stream.filters[0]}" + f"The initial conditions will be reset on filters:\n{stream.filters[0]}" ) in caplog.text + assert repr(stream.filters[1]) not in caplog.text stream.disconnect() From e9a9a169ca74578145c9e3d0c686d16f494be35f Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 15:05:38 +0100 Subject: [PATCH 50/69] fix x-ref to base class --- mne_lsl/stream/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 3fa9aa1e6..6dd3ca08c 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -316,7 +316,7 @@ def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> Non idx : ``'all'``| int | list of int | tuple of int If the string ``'all'`` (default), remove all filters. If an integer or a list of integers, remove the filter(s) at the given index(es) from - :py:attr:`Stream.filters`. + ``Stream.filters``. Notes ----- From a62b9f0b5ea2287f466a964a25b090a5e0f293dc Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 15:05:51 +0100 Subject: [PATCH 51/69] improve StreamFilter instantiation logic --- mne_lsl/stream/_filters.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/mne_lsl/stream/_filters.py b/mne_lsl/stream/_filters.py index 44170cdd0..aff7e95e6 100644 --- a/mne_lsl/stream/_filters.py +++ b/mne_lsl/stream/_filters.py @@ -20,10 +20,26 @@ class StreamFilter(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + if "iir_params" not in self: + warn( + "The 'iir_params' key is missing, which is unexpected.", + RuntimeWarning, + stacklevel=2, + ) + self["iir_params"] = dict() for key in ("ftype", "order"): - if key in self: - assert key in self["iir_params"] # sanity-check - del self[key] + if key not in self: + continue + if key not in self["iir_params"]: + self["iir_params"][key] = self[key] + else: + if self[key] != self["iir_params"][key]: + raise RuntimeError( + f"The value of '{key}' in the filter dictionary and in the " + "filter parameters '{iir_params}' is inconsistent. " + f"{self[key]} != {self['iir_params'][key]}." + ) + del self[key] def __repr__(self): # noqa: D105 order = self._ORDER_STR.get( From d0be4c784e98d3fb065c8e15484a26780042140b Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 15:14:52 +0100 Subject: [PATCH 52/69] add tests --- mne_lsl/stream/tests/test_filters.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mne_lsl/stream/tests/test_filters.py b/mne_lsl/stream/tests/test_filters.py index ceb78b29e..9b3e4b2d8 100644 --- a/mne_lsl/stream/tests/test_filters.py +++ b/mne_lsl/stream/tests/test_filters.py @@ -66,6 +66,14 @@ def test_StreamFilter(iir_params: dict[str, Any], sfreq: float): # test creation from self filt2 = StreamFilter(filt) assert filt == filt2 + # test invalid creation + del filt2["iir_params"] + with pytest.warns(RuntimeWarning, match=" 'iir_params' key is missing"): + StreamFilter(filt2) + filt2["iir_params"] = filt["iir_params"] + filt2["order"] = 101 + with pytest.raises(RuntimeError, match="inconsistent"): + StreamFilter(filt2) def test_StreamFilter_comparison(filters: list[StreamFilter]): From 7a15822b03437403030541b279e518c271083d5e Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 15:35:37 +0100 Subject: [PATCH 53/69] add test for channel selection inc filters --- mne_lsl/stream/_base.py | 1 + mne_lsl/stream/tests/test_stream_lsl.py | 29 ++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 6dd3ca08c..0e2b13f00 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -493,6 +493,7 @@ def filter( # add filter to the list of applied filters with self._interrupt_acquisition(): self._filters.append(StreamFilter(filt)) + return self @copy_doc(ContainsMixin.get_channel_types) def get_channel_types( diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 12e627e6d..65b83e0b5 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -641,7 +641,7 @@ def test_stream_annotations_picks(_mock_lsl_stream_annotations): def test_stream_filter_deletion(mock_lsl_stream, caplog): """Test deletion of filters applied to a Stream.""" # test no filter - stream = Stream(bufsize=2.0).connect() + stream = Stream(bufsize=2.0, name=mock_lsl_stream.name).connect() time.sleep(0.1) with pytest.raises(RuntimeError, match="No filter to remove."): stream.del_filter("all") @@ -691,3 +691,30 @@ def test_stream_filter_deletion(mock_lsl_stream, caplog): ) in caplog.text assert repr(stream.filters[1]) not in caplog.text stream.disconnect() + + +def test_stream_filter_picks(mock_lsl_stream): + """Test picks from a StreamFilter.""" + stream = ( + Stream(bufsize=2.0, name=mock_lsl_stream.name) + .connect() + .filter(l_freq=1.0, h_freq=40.0, picks="eeg") + ) + assert len(stream.filters) == 1 + assert_allclose( + stream.filters[0]["picks"], + _picks_to_idx(mock_lsl_stream.info, picks="eeg", exclude=()), + ) + stream.pick(["F7", "F3", "Fz", "F4", "F8"]) # consecutive EEG-only channels + assert_allclose(stream.filters[0]["picks"], np.arange(5)) + stream.pick(["F3", "F4"]) # non-consecutive EEG-only channels + assert_allclose(stream.filters[0]["picks"], np.arange(2)) + stream.disconnect().connect() # reset + stream.filter(l_freq=None, h_freq=100.0, picks=("eeg", "ecg", "eog")) + assert len(stream.filters) == 1 + picks_ = _picks_to_idx( + mock_lsl_stream.info, picks=("eeg", "ecg", "eog"), exclude=() + ) + assert_allclose(stream.filters[0]["picks"], picks_) + stream.drop_channels(["ECG"]) # -2 channel + assert_allclose(stream.filters[0]["picks"], picks_[:-1]) From 5d1cac9a31f45378fc036f0e75fda33c3b6d9949 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 16:01:42 +0100 Subject: [PATCH 54/69] for now, prevent pick after filter --- mne_lsl/stream/_base.py | 19 ++++++++++++------- mne_lsl/stream/tests/test_stream_lsl.py | 1 + 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 0e2b13f00..156de3728 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -988,7 +988,12 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: if picks_inlet.size == 0: raise RuntimeError( "The requested channel selection would not leave any channel from the " - "LSL Stream." + "Stream." + ) + if len(self._filters) != 0: + raise RuntimeError( + "The channel selection must be done before adding filters to the " + "Stream." ) with self._interrupt_acquisition(): @@ -999,14 +1004,14 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: for ch in self._added_channels[::-1]: if ch not in self.ch_names: self._added_channels.remove(ch) - # remove dropped channels from filters - for filt in self._filters: - # TODO: ensure correct selection of channels. - filt["picks"] = np.intersect1d(filt["picks"], picks, assume_unique=True) - # TODO: don't reset, select initial conditions. - filt["zi"] = None self._filters = [filt for filt in self._filters if filt["picks"].size != 0] + def _map_picks_to_buffer(self, picks, picks2map): + mask = -np.ones(self._buffer.shape[1], dtype=int) + mask[picks2map] = picks2map + mask = mask[picks] + return mask[np.where(mask != -1)] + @abstractmethod def _reset_variables(self) -> None: """Reset variables define after connection.""" diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 65b83e0b5..8b8996d53 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -693,6 +693,7 @@ def test_stream_filter_deletion(mock_lsl_stream, caplog): stream.disconnect() +@pytest.mark.skip(reason="Not yet implemented.") def test_stream_filter_picks(mock_lsl_stream): """Test picks from a StreamFilter.""" stream = ( From e8c8a7ba21590bc304f27f2b6987564184cf1e6c Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 16:43:00 +0100 Subject: [PATCH 55/69] rm snippet --- mne_lsl/stream/_base.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 156de3728..132e0b555 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1006,12 +1006,6 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: self._added_channels.remove(ch) self._filters = [filt for filt in self._filters if filt["picks"].size != 0] - def _map_picks_to_buffer(self, picks, picks2map): - mask = -np.ones(self._buffer.shape[1], dtype=int) - mask[picks2map] = picks2map - mask = mask[picks] - return mask[np.where(mask != -1)] - @abstractmethod def _reset_variables(self) -> None: """Reset variables define after connection.""" From c90b749330f05a0c366b72d18509e4f2775b8dfe Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 16:45:36 +0100 Subject: [PATCH 56/69] better fixture names --- mne_lsl/stream/tests/test_stream_lsl.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 8b8996d53..632e1c99e 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -49,7 +49,7 @@ def acquisition_delay(request): @pytest.fixture(scope="function") -def _mock_lsl_stream_int(request): +def mock_lsl_stream_int(request): """Create a mock LSL stream streaming the channel number continuously.""" # nest the PlayerLSL import to first write the temporary LSL configuration file from mne_lsl.player import PlayerLSL # noqa: E402 @@ -63,7 +63,7 @@ def _mock_lsl_stream_int(request): @pytest.fixture(scope="function") -def _mock_lsl_stream_annotations(raw_annotations, request): +def mock_lsl_stream_annotations(raw_annotations, request): """Create a mock LSL stream streaming the channel number continuously.""" # nest the PlayerLSL import to first write the temporary LSL configuration file from mne_lsl.player import PlayerLSL # noqa: E402 @@ -476,9 +476,9 @@ def test_stream_invalid_interrupt(mock_lsl_stream): pass -def test_stream_rereference(_mock_lsl_stream_int, acquisition_delay): +def test_stream_rereference(mock_lsl_stream_int, acquisition_delay): """Test re-referencing an EEG-like stream.""" - stream = Stream(bufsize=0.4, name=_mock_lsl_stream_int.name) + stream = Stream(bufsize=0.4, name=mock_lsl_stream_int.name) stream.connect(acquisition_delay=acquisition_delay) time.sleep(0.1) # give a bit of time to slower CIs assert stream.n_new_samples > 0 @@ -490,7 +490,7 @@ def test_stream_rereference(_mock_lsl_stream_int, acquisition_delay): data_ref = np.full(data.shape, np.arange(data.shape[0]).reshape(-1, 1)) data_ref -= data_ref[1, :] assert_allclose(data, data_ref) - _sleep_until_new_data(acquisition_delay, _mock_lsl_stream_int) + _sleep_until_new_data(acquisition_delay, mock_lsl_stream_int) data, _ = stream.get_data() assert_allclose(data, data_ref) @@ -524,15 +524,15 @@ def test_stream_rereference(_mock_lsl_stream_int, acquisition_delay): data_ref[-1, :] = np.zeros(data.shape[1]) data_ref -= data_ref[[1, 2], :].mean(axis=0, keepdims=True) assert_allclose(data, data_ref) - _sleep_until_new_data(stream._acquisition_delay, _mock_lsl_stream_int) + _sleep_until_new_data(stream._acquisition_delay, mock_lsl_stream_int) data, _ = stream.get_data() assert_allclose(data, data_ref) stream.disconnect() -def test_stream_rereference_average(_mock_lsl_stream_int): +def test_stream_rereference_average(mock_lsl_stream_int): """Test average re-referencing schema.""" - stream = Stream(bufsize=0.4, name=_mock_lsl_stream_int.name) + stream = Stream(bufsize=0.4, name=mock_lsl_stream_int.name) stream.connect() time.sleep(0.1) # give a bit of time to slower CIs stream.set_channel_types({"2": "ecg"}) # channels: 0, 1, 2, 3, 4 @@ -543,7 +543,7 @@ def test_stream_rereference_average(_mock_lsl_stream_int): ) data_ref[-2:, :] += 1 assert_allclose(data, data_ref) - _sleep_until_new_data(stream._acquisition_delay, _mock_lsl_stream_int) + _sleep_until_new_data(stream._acquisition_delay, mock_lsl_stream_int) data, _ = stream.get_data(picks="eeg") assert_allclose(data, data_ref) @@ -628,10 +628,9 @@ def test_stream_irregularly_sampled(close_io): close_io() -def test_stream_annotations_picks(_mock_lsl_stream_annotations): +def test_stream_annotations_picks(mock_lsl_stream_annotations): """Test sub-selection of annotations.""" - stream = Stream(bufsize=5, stype="annotations").connect() # test chaining as-well - stream.pick("test1") # most-present annotations + stream = Stream(bufsize=5, stype="annotations").connect().pick("test1") time.sleep(5) # acquire data data, ts = stream.get_data() assert np.count_nonzero(data) == data.size From d6c4f32ba6d547e7a2cc5d1717fb03b64a7cc630 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 16:57:37 +0100 Subject: [PATCH 57/69] fix type hnits --- mne_lsl/conftest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mne_lsl/conftest.py b/mne_lsl/conftest.py index 5eed97fb7..4e113e083 100644 --- a/mne_lsl/conftest.py +++ b/mne_lsl/conftest.py @@ -8,7 +8,7 @@ import numpy as np from mne import Annotations from mne import set_log_level as set_log_level_mne -from mne.io import Raw, read_raw_fif +from mne.io import read_raw_fif from pytest import fixture from mne_lsl import logger, set_log_level @@ -18,6 +18,7 @@ if TYPE_CHECKING: from pathlib import Path + from mne.io import BaseRaw from pytest import Config # Set debug logging in LSL, e.g.: @@ -112,13 +113,13 @@ def fname(tmp_path_factory) -> Path: @fixture(scope="function") -def raw(fname: Path) -> Raw: +def raw(fname: Path) -> BaseRaw: """Return the raw file corresponding to fname.""" return read_raw_fif(fname, preload=True) @fixture(scope="function") -def raw_annotations(raw: Raw) -> Raw: +def raw_annotations(raw: BaseRaw) -> BaseRaw: """Return a raw file with annotations.""" annotations = Annotations( onset=[0.1, 0.4, 0.5, 0.8, 0.95, 1.1, 1.3], From f9918f98bf97e1dda8c9e1b3d2ed26df8e401c58 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 17:07:56 +0100 Subject: [PATCH 58/69] filtersadd fixture to test --- mne_lsl/stream/tests/test_stream_lsl.py | 30 +++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 632e1c99e..42ae4de29 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -1,9 +1,12 @@ +from __future__ import annotations # c.f. PEP 563, PEP 649 + import logging import os import platform import re import time from datetime import datetime, timezone +from typing import TYPE_CHECKING import numpy as np import pytest @@ -21,13 +24,13 @@ from mne.io.constants import FIFF from mne.io.pick import _picks_to_idx -from mne_lsl import logger from mne_lsl.lsl import StreamInfo, StreamOutlet from mne_lsl.stream import StreamLSL as Stream from mne_lsl.utils._tests import match_stream_and_raw_data from mne_lsl.utils.logs import _use_log_level -logger.propagate = True +if TYPE_CHECKING: + from mne.io import BaseRaw bad_gh_macos = pytest.mark.skipif( @@ -72,6 +75,29 @@ def mock_lsl_stream_annotations(raw_annotations, request): yield player +def raw_sinusoids() -> BaseRaw: + """Create a raw object with sinusoids.""" + times = np.linspace(0, 10, 10001) # 1000 Hz + data1 = np.sin(2 * np.pi * 10 * times) + np.sin(2 * np.pi * 30 * times) + data2 = np.sin(2 * np.pi * 30 * times) + np.sin(2 * np.pi * 50 * times) + data3 = np.sin(2 * np.pi * 30 * times) + np.sin(2 * np.pi * 100 * times) + data = np.vstack([data1, data2, data3]) + info = create_info( + ch_names=["10-30", "30-50", "30-100"], sfreq=1000, ch_types="eeg" + ) + return RawArray(data, info) + + +@pytest.fixture(scope="function") +def mock_lsl_stream_sinusoids(raw_sinusoids, request): + """Create a mock LSL stream streaming sinusoids.""" + # nest the PlayerLSL import to first write the temporary LSL configuration file + from mne_lsl.player import PlayerLSL + + with PlayerLSL(raw_sinusoids, name=f"P_{request.node.name}") as player: + yield player + + def test_stream(mock_lsl_stream, acquisition_delay, raw): """Test a valid Stream.""" # test connect/disconnect From 1da11ebbdad934b6852048166a2df2f47332fd55 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 6 Mar 2024 18:05:31 +0100 Subject: [PATCH 59/69] add test --- mne_lsl/stream/tests/test_stream_lsl.py | 30 +++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 42ae4de29..b12378e47 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -16,6 +16,8 @@ from mne.io import RawArray from mne.utils import check_version from numpy.testing import assert_allclose +from scipy.fft import fft, fftfreq +from scipy.signal import find_peaks if check_version("mne", "1.6"): from mne._fiff.constants import FIFF @@ -75,9 +77,10 @@ def mock_lsl_stream_annotations(raw_annotations, request): yield player +@pytest.fixture(scope="function") def raw_sinusoids() -> BaseRaw: """Create a raw object with sinusoids.""" - times = np.linspace(0, 10, 10001) # 1000 Hz + times = np.arange(0, 2, 1 / 1000) data1 = np.sin(2 * np.pi * 10 * times) + np.sin(2 * np.pi * 30 * times) data2 = np.sin(2 * np.pi * 30 * times) + np.sin(2 * np.pi * 50 * times) data3 = np.sin(2 * np.pi * 30 * times) + np.sin(2 * np.pi * 100 * times) @@ -583,7 +586,7 @@ def test_stream_rereference_average(mock_lsl_stream_int): data_ref[-2:, :] += 1 data_ref -= data_ref.mean(axis=0, keepdims=True) assert_allclose(data, data_ref) - _sleep_until_new_data(stream._acquisition_delay, _mock_lsl_stream_int) + _sleep_until_new_data(stream._acquisition_delay, mock_lsl_stream_int) data, _ = stream.get_data(picks="eeg") assert_allclose(data, data_ref) stream.disconnect() @@ -744,3 +747,26 @@ def test_stream_filter_picks(mock_lsl_stream): assert_allclose(stream.filters[0]["picks"], picks_) stream.drop_channels(["ECG"]) # -2 channel assert_allclose(stream.filters[0]["picks"], picks_[:-1]) + + +def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): + """Test stream filters.""" + freqs = fftfreq(raw_sinusoids.times.size, 1 / raw_sinusoids.info["sfreq"]) + idx = np.where(0 <= freqs)[0] + freqs = freqs[idx] + fft_orig = np.abs(fft(raw_sinusoids.get_data(), axis=-1)[:, idx]) + # extract peaks + assert fft_orig.shape[0] == len(raw_sinusoids.ch_names) + assert fft_orig.shape[0] == len(mock_lsl_stream_sinusoids.ch_names) + heights_orig = dict() + for k in range(fft_orig.shape[0]): + peaks, _ = find_peaks(fft_orig[k, :], height=100) # peak height is 1000 + fqs = [int(elt) for elt in raw_sinusoids.ch_names[k].split("-")] + assert_allclose(freqs[peaks], fqs, atol=0.1) + heights_orig[k] = dict(idx=peaks, heights=fft_orig[k, peaks]) + # test unfiltered data + stream = Stream(bufsize=2.0, name=mock_lsl_stream_sinusoids.name).connect() + time.sleep(2.1) + fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) + for ch, ch_height in heights_orig.items(): + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"]) From 5f0868e8612d04380d70e6bc22216076b689ec75 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 15:42:59 +0100 Subject: [PATCH 60/69] add tests --- mne_lsl/stream/tests/test_stream_lsl.py | 73 ++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index b12378e47..bb24fccf5 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -747,6 +747,7 @@ def test_stream_filter_picks(mock_lsl_stream): assert_allclose(stream.filters[0]["picks"], picks_) stream.drop_channels(["ECG"]) # -2 channel assert_allclose(stream.filters[0]["picks"], picks_[:-1]) + stream.disconnect() def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): @@ -769,4 +770,74 @@ def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): time.sleep(2.1) fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) for ch, ch_height in heights_orig.items(): - assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"]) + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + # test filtering + stream.filter(5, 15, picks="10-30") + time.sleep(2.1) + fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) + for ch, ch_height in heights_orig.items(): + if ch == 0: # 10 Hz retained, 30 Hz removed + assert fft_[ch, ch_height["idx"]][1] < 0.1 * ch_height["heights"][1] + assert_allclose( + fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.01 + ) + else: + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + # test removing filter + stream.del_filter(0) + time.sleep(2.1) + fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) + for ch, ch_height in heights_orig.items(): + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + # test adding multiple filters + stream.filter(20, 70, picks="eeg") + time.sleep(2.1) + fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) + for ch, ch_height in heights_orig.items(): + if ch == 0: # 10 Hz removed, 30 Hz retained + assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] + assert_allclose( + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + ) + elif ch == 1: # 30 Hz retained, 50 Hz retained + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + elif ch == 2: # 30 Hz retained, 100 Hz removed (but not as much attenuation) + assert fft_[ch, ch_height["idx"]][1] < 0.15 * ch_height["heights"][1] + assert_allclose( + fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.01 + ) + stream.filter(40, 60, picks="30-50") # second filter + time.sleep(2.1) + fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) + for ch, ch_height in heights_orig.items(): + if ch == 0: # 10 Hz removed, 30 Hz retained + assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] + assert_allclose( + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + ) + elif ch == 1: # 30 Hz removed, 50 Hz retained + assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] + assert_allclose( + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + ) + elif ch == 2: # 30 Hz retained, 100 Hz removed + assert_allclose( + fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.01 + ) + assert fft_[ch, ch_height["idx"]][1] < 0.15 * ch_height["heights"][1] + stream.filter(40, 60, picks="eeg") # third filter + time.sleep(2.1) + fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) + for ch, ch_height in heights_orig.items(): + if ch == 0: # 10 Hz removed, 30 Hz removed + assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] + assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] + elif ch == 1: # 30 Hz removed, 50 Hz retained + assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] + assert_allclose( + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + ) + elif ch == 2: # 30 Hz removed, 100 Hz removed + assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] + assert fft_[ch, ch_height["idx"]][1] < 0.15 * ch_height["heights"][1] + stream.disconnect() From 9bfa931ec6e8c433de07c6c88923cfae6cd195d0 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 15:45:26 +0100 Subject: [PATCH 61/69] fix typos --- mne_lsl/utils/_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_lsl/utils/_docs.py b/mne_lsl/utils/_docs.py index 40a035ea9..0f9423542 100644 --- a/mne_lsl/utils/_docs.py +++ b/mne_lsl/utils/_docs.py @@ -45,7 +45,7 @@ # -- G --------------------------------------------------------------------------------- # -- H --------------------------------------------------------------------------------- docdict["h_freq"] = """h_freq : float | None - The higher cutoff frequency. If None, the buffer is only high-passed.""" + The higher cutoff frequency. If None, the buffer is only high-passed.""" # -- I --------------------------------------------------------------------------------- docdict["iir_params"] = """ @@ -63,7 +63,7 @@ # -- L --------------------------------------------------------------------------------- docdict["l_freq"] = """ l_freq : float | None - The lower cutoff frequency. If None, the buffer is only low-passed.""" + The lower cutoff frequency. If None, the buffer is only low-passed.""" # -- M --------------------------------------------------------------------------------- # -- N --------------------------------------------------------------------------------- From dc684c4bea0293327cebe1d35d155a42b450b146 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 16:18:13 +0100 Subject: [PATCH 62/69] better rtol --- mne_lsl/stream/tests/test_stream_lsl.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index bb24fccf5..276676bce 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -770,7 +770,7 @@ def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): time.sleep(2.1) fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) for ch, ch_height in heights_orig.items(): - assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.05) # test filtering stream.filter(5, 15, picks="10-30") time.sleep(2.1) @@ -779,16 +779,16 @@ def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): if ch == 0: # 10 Hz retained, 30 Hz removed assert fft_[ch, ch_height["idx"]][1] < 0.1 * ch_height["heights"][1] assert_allclose( - fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.01 + fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.05 ) else: - assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.05) # test removing filter stream.del_filter(0) time.sleep(2.1) fft_ = np.abs(fft(stream.get_data()[0], axis=-1)[:, idx]) for ch, ch_height in heights_orig.items(): - assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.05) # test adding multiple filters stream.filter(20, 70, picks="eeg") time.sleep(2.1) @@ -797,14 +797,14 @@ def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): if ch == 0: # 10 Hz removed, 30 Hz retained assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] assert_allclose( - fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.05 ) elif ch == 1: # 30 Hz retained, 50 Hz retained - assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.01) + assert_allclose(fft_[ch, ch_height["idx"]], ch_height["heights"], rtol=0.05) elif ch == 2: # 30 Hz retained, 100 Hz removed (but not as much attenuation) assert fft_[ch, ch_height["idx"]][1] < 0.15 * ch_height["heights"][1] assert_allclose( - fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.01 + fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.05 ) stream.filter(40, 60, picks="30-50") # second filter time.sleep(2.1) @@ -813,16 +813,16 @@ def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): if ch == 0: # 10 Hz removed, 30 Hz retained assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] assert_allclose( - fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.05 ) elif ch == 1: # 30 Hz removed, 50 Hz retained assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] assert_allclose( - fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.05 ) elif ch == 2: # 30 Hz retained, 100 Hz removed assert_allclose( - fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.01 + fft_[ch, ch_height["idx"]][0], ch_height["heights"][0], rtol=0.05 ) assert fft_[ch, ch_height["idx"]][1] < 0.15 * ch_height["heights"][1] stream.filter(40, 60, picks="eeg") # third filter @@ -835,7 +835,7 @@ def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): elif ch == 1: # 30 Hz removed, 50 Hz retained assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] assert_allclose( - fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.01 + fft_[ch, ch_height["idx"]][1], ch_height["heights"][1], rtol=0.05 ) elif ch == 2: # 30 Hz removed, 100 Hz removed assert fft_[ch, ch_height["idx"]][0] < 0.1 * ch_height["heights"][0] From b969bd331145f9a073b51b0314a2024a89672e39 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 20:57:31 +0100 Subject: [PATCH 63/69] add note --- tutorials/00_introduction.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tutorials/00_introduction.py b/tutorials/00_introduction.py index d571bc088..445d1eea8 100644 --- a/tutorials/00_introduction.py +++ b/tutorials/00_introduction.py @@ -159,6 +159,15 @@ # directly on the ring buffer. For instance, we can select the EEG channels, add the # missing reference channel and re-reference using a common average referencing scheme # which will reduce the ring buffer to 64 channels. +# +# .. note:: +# +# By design, once a re-referencing operation is performed or if at least one filter +# is applied, it is not possible anymore to select a subset of channels with the +# methods :meth:`~mne_lsl.stream.StreamLSL.pick` or +# :meth:`~mne_lsl.stream.StreamLSL.drop_channels`. Note that the re-referencing is +# not reversible while filters can be removed with the method +# :meth:`~mne_lsl.stream.StreamLSL.del_filter`. stream.pick("eeg") # channel selection assert "CPz" not in stream.ch_names # reference absent from the data stream From 0669f402f04771a7cc2629e17cda165378bd8da8 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 21:22:31 +0100 Subject: [PATCH 64/69] rm line about filters in picks --- mne_lsl/stream/_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 132e0b555..ed4dbec3b 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1004,7 +1004,6 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: for ch in self._added_channels[::-1]: if ch not in self.ch_names: self._added_channels.remove(ch) - self._filters = [filt for filt in self._filters if filt["picks"].size != 0] @abstractmethod def _reset_variables(self) -> None: From addef55f52e42c9a7bd52a8551941d4c0e2402d5 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 21:22:57 +0100 Subject: [PATCH 65/69] better --- mne_lsl/stream/_base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index ed4dbec3b..f38174a91 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -983,7 +983,6 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: "The channel selection must be done before adding a re-refenrecing " "schema with Stream.set_eeg_reference()." ) - picks_inlet = picks[np.where(picks < self._picks_inlet.size)[0]] if picks_inlet.size == 0: raise RuntimeError( @@ -995,7 +994,6 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None: "The channel selection must be done before adding filters to the " "Stream." ) - with self._interrupt_acquisition(): self._info = pick_info(self._info, picks, verbose=logger.level) self._picks_inlet = self._picks_inlet[picks_inlet] From 2d221915f3456a0c3eec5543fa2323cdfdf351fd Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 21:23:49 +0100 Subject: [PATCH 66/69] add dtype --- mne_lsl/stream/_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index f38174a91..adc53372f 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1068,7 +1068,9 @@ def dtype(self) -> Optional[DTypeLike]: @property def filters(self) -> list[StreamFilter]: - """List of filters applied to the real-time Stream.""" + """List of filters applied to the real-time Stream. + + :type: :class:`list` of ```StreamFilter``""" return self._filters @property From b1a3a4f1aca4e8bddf04a5afbcf86ef6e3eb4903 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 21:24:01 +0100 Subject: [PATCH 67/69] fix --- mne_lsl/stream/_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index adc53372f..82b45503b 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -1070,7 +1070,8 @@ def dtype(self) -> Optional[DTypeLike]: def filters(self) -> list[StreamFilter]: """List of filters applied to the real-time Stream. - :type: :class:`list` of ```StreamFilter``""" + :type: :class:`list` of ```StreamFilter`` + """ return self._filters @property From fce0ad825f91d0bfe7eb30cfbf8dd38aa2b6fe59 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 21:25:24 +0100 Subject: [PATCH 68/69] fix docstrings --- mne_lsl/stream/_base.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/mne_lsl/stream/_base.py b/mne_lsl/stream/_base.py index 82b45503b..fe700ecf3 100644 --- a/mne_lsl/stream/_base.py +++ b/mne_lsl/stream/_base.py @@ -431,8 +431,7 @@ def filter( """Filter the stream with an IIR causal filter. Once a filter is applied, the buffer is updated in real-time with the filtered - data. It is not possible to remove an applied filter. It is possible to apply - more than one filter. + data. It is possible to apply more than one filter. .. code-block:: python @@ -452,12 +451,6 @@ def filter( ------- stream : instance of ``Stream`` The stream instance modified in-place. - - Notes - ----- - Adding a filter on channels already filtered will reset the initial conditions - of those channels. The initial conditions will be re-estimated as a step - response steady-state to the combination of both filters. """ self._check_connected_and_regular_sampling("filter()") # validate the arguments and ensure 'sos' output From 25c6ec94556308a2e51992549a328e38d9a2550c Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Thu, 7 Mar 2024 21:27:24 +0100 Subject: [PATCH 69/69] rm test --- mne_lsl/stream/tests/test_stream_lsl.py | 29 ------------------------- 1 file changed, 29 deletions(-) diff --git a/mne_lsl/stream/tests/test_stream_lsl.py b/mne_lsl/stream/tests/test_stream_lsl.py index 276676bce..e53f80d7e 100644 --- a/mne_lsl/stream/tests/test_stream_lsl.py +++ b/mne_lsl/stream/tests/test_stream_lsl.py @@ -721,35 +721,6 @@ def test_stream_filter_deletion(mock_lsl_stream, caplog): stream.disconnect() -@pytest.mark.skip(reason="Not yet implemented.") -def test_stream_filter_picks(mock_lsl_stream): - """Test picks from a StreamFilter.""" - stream = ( - Stream(bufsize=2.0, name=mock_lsl_stream.name) - .connect() - .filter(l_freq=1.0, h_freq=40.0, picks="eeg") - ) - assert len(stream.filters) == 1 - assert_allclose( - stream.filters[0]["picks"], - _picks_to_idx(mock_lsl_stream.info, picks="eeg", exclude=()), - ) - stream.pick(["F7", "F3", "Fz", "F4", "F8"]) # consecutive EEG-only channels - assert_allclose(stream.filters[0]["picks"], np.arange(5)) - stream.pick(["F3", "F4"]) # non-consecutive EEG-only channels - assert_allclose(stream.filters[0]["picks"], np.arange(2)) - stream.disconnect().connect() # reset - stream.filter(l_freq=None, h_freq=100.0, picks=("eeg", "ecg", "eog")) - assert len(stream.filters) == 1 - picks_ = _picks_to_idx( - mock_lsl_stream.info, picks=("eeg", "ecg", "eog"), exclude=() - ) - assert_allclose(stream.filters[0]["picks"], picks_) - stream.drop_channels(["ECG"]) # -2 channel - assert_allclose(stream.filters[0]["picks"], picks_[:-1]) - stream.disconnect() - - def test_stream_filter(mock_lsl_stream_sinusoids, raw_sinusoids): """Test stream filters.""" freqs = fftfreq(raw_sinusoids.times.size, 1 / raw_sinusoids.info["sfreq"])