Add filters to a stream object #218

Merged · 70 commits · Mar 7, 2024
Commits (70)
16f2b65
add arguments and description of the filtering method
mscheltienne Feb 27, 2024
667b343
fix picks/docstrings
mscheltienne Feb 27, 2024
ead65c3
add filter validation
mscheltienne Feb 27, 2024
cd736ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2024
d29590a
fix logging in filter function
mscheltienne Feb 27, 2024
92bd919
Merge remote-tracking branch 'upstream/main' into filters
mscheltienne Feb 27, 2024
a2c409f
fix B006
mscheltienne Feb 27, 2024
3727e83
use create_filter directly and add filter storage
mscheltienne Feb 28, 2024
0aaa290
drop use_log_level
mscheltienne Feb 28, 2024
0682c83
edit filters during pick operation
mscheltienne Feb 28, 2024
b2a9ecd
add filter placeholder
mscheltienne Feb 28, 2024
213f526
add filtering of every buffer
mscheltienne Feb 28, 2024
aa765d6
fix data indexing
mscheltienne Feb 29, 2024
49e7d8e
fix variable name
mscheltienne Feb 29, 2024
fc52dd6
improve assertion and array selection
mscheltienne Feb 29, 2024
4dfff52
add notes about reset
mscheltienne Feb 29, 2024
b7745ee
add idea to sanitize filter list
mscheltienne Feb 29, 2024
d92dee1
add tests
mscheltienne Feb 29, 2024
3a04cdf
add more tests
mscheltienne Feb 29, 2024
c80c6f5
simpler comparison
mscheltienne Feb 29, 2024
518ea87
fix tests
mscheltienne Feb 29, 2024
ec960f7
fix tests
mscheltienne Feb 29, 2024
0d2a91d
more tests
mscheltienne Feb 29, 2024
21afacd
improve definition and handling of filters
mscheltienne Mar 1, 2024
827123f
better
mscheltienne Mar 1, 2024
585ed5d
fix typos
mscheltienne Mar 1, 2024
7f0d1d0
add test for (un)combination of filters
mscheltienne Mar 1, 2024
aafc5f4
rm bad sanitize_filters tests
mscheltienne Mar 1, 2024
9ff3e55
improve type-hints
mscheltienne Mar 1, 2024
10eef60
trigger cis
mscheltienne Mar 1, 2024
540fa45
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2024
26526bb
add tests and fix typos
mscheltienne Mar 1, 2024
07c01a9
fix style
mscheltienne Mar 1, 2024
6700fec
fix imports [ci skip]
mscheltienne Mar 1, 2024
eb5464f
add tests
mscheltienne Mar 1, 2024
5c659f9
more tests [ci skip]
mscheltienne Mar 1, 2024
cb1a2c5
more tests
mscheltienne Mar 1, 2024
b1d6dce
fix import
mscheltienne Mar 1, 2024
6f6aff0
better comparison
mscheltienne Mar 1, 2024
b2bbd1a
re-simplify
mscheltienne Mar 4, 2024
64dcd14
simplify stream code
mscheltienne Mar 4, 2024
1ccf2e1
add entries to docdict to de-duplicate docstrings
mscheltienne Mar 4, 2024
6c8f59c
fix tests for simplification
mscheltienne Mar 4, 2024
e343d8a
fix typos
mscheltienne Mar 4, 2024
13bf76b
add test for create_filter
mscheltienne Mar 4, 2024
2f4c701
add filters property and method to delete filters
mscheltienne Mar 4, 2024
c102df7
add test placeholder and sort idx for deletion
mscheltienne Mar 4, 2024
3bf8b47
add logs and test for deletion [skip ci]
mscheltienne Mar 4, 2024
32347a9
fix deletion test through logs
mscheltienne Mar 4, 2024
30f8d73
better
mscheltienne Mar 4, 2024
e9a9a16
fix x-ref to base class
mscheltienne Mar 6, 2024
a62b9f0
improve StreamFilter instantiation logic
mscheltienne Mar 6, 2024
d0be4c7
add tests
mscheltienne Mar 6, 2024
7a15822
add test for channel selection inc filters
mscheltienne Mar 6, 2024
5d1cac9
for now, prevent pick after filter
mscheltienne Mar 6, 2024
e8c8a7b
rm snippet
mscheltienne Mar 6, 2024
c90b749
better fixture names
mscheltienne Mar 6, 2024
d6c4f32
fix type hints
mscheltienne Mar 6, 2024
f9918f9
add filters fixture to test
mscheltienne Mar 6, 2024
1da11eb
add test
mscheltienne Mar 6, 2024
5f0868e
add tests
mscheltienne Mar 7, 2024
9bfa931
fix typos
mscheltienne Mar 7, 2024
dc684c4
better rtol
mscheltienne Mar 7, 2024
b969bd3
add note
mscheltienne Mar 7, 2024
0669f40
rm line about filters in picks
mscheltienne Mar 7, 2024
addef55
better
mscheltienne Mar 7, 2024
2d22191
add dtype
mscheltienne Mar 7, 2024
b1a3a4f
fix
mscheltienne Mar 7, 2024
fce0ad8
fix docstrings
mscheltienne Mar 7, 2024
25c6ec9
rm test
mscheltienne Mar 7, 2024
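
This PR adds causal IIR filtering to the real-time stream object: a Stream.filter() method, a Stream.del_filter() method, and a read-only Stream.filters property. The snippet below is a minimal usage sketch based on the docstrings added in the diff; the import path (StreamLSL), stream name, cutoff frequencies, and the presence of EEG/ECG channels are illustrative assumptions, not part of this PR.

from mne_lsl.stream import StreamLSL as Stream  # assumed import path

# Connect to a (hypothetical) LSL stream and keep a 2-second buffer.
stream = Stream(bufsize=2.0, name="my-stream").connect()

# Apply two causal IIR filters on different channel selections; the buffer
# is then updated in real time with the filtered data.
stream.filter(1.0, 40.0, picks="eeg")
stream.filter(1.0, 15.0, picks="ecg")

# Inspect the applied filters, then remove one by index or all of them with
# the default idx="all".
print(stream.filters)
stream.del_filter(1)
stream.del_filter("all")

stream.disconnect()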
10 changes: 6 additions & 4 deletions mne_lsl/conftest.py
@@ -8,16 +8,17 @@
import numpy as np
from mne import Annotations
from mne import set_log_level as set_log_level_mne
from mne.io import Raw, read_raw_fif
from mne.io import read_raw_fif
from pytest import fixture

from mne_lsl import set_log_level
from mne_lsl import logger, set_log_level
from mne_lsl.datasets import testing
from mne_lsl.lsl import StreamInlet, StreamOutlet

if TYPE_CHECKING:
from pathlib import Path

from mne.io import BaseRaw
from pytest import Config

# Set debug logging in LSL, e.g.:
@@ -59,6 +60,7 @@ def pytest_configure(config: Config) -> None:
config.addinivalue_line("filterwarnings", warning_line)
set_log_level_mne("WARNING") # MNE logger
set_log_level("DEBUG") # MNE-lsl logger
logger.propagate = True


def pytest_sessionfinish(session, exitstatus) -> None:
@@ -111,13 +113,13 @@ def fname(tmp_path_factory) -> Path:


@fixture(scope="function")
def raw(fname: Path) -> Raw:
def raw(fname: Path) -> BaseRaw:
"""Return the raw file corresponding to fname."""
return read_raw_fif(fname, preload=True)


@fixture(scope="function")
def raw_annotations(raw: Raw) -> Raw:
def raw_annotations(raw: BaseRaw) -> BaseRaw:
"""Return a raw file with annotations."""
annotations = Annotations(
onset=[0.1, 0.4, 0.5, 0.8, 0.95, 1.1, 1.3],
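
The logger.propagate = True line added to pytest_configure lets records emitted through mne_lsl's logger bubble up to the root logger, which is what pytest's built-in caplog fixture captures — and the deletion tests added in this PR assert on log output. A hypothetical test sketch (the mock_stream fixture and channel selection are assumptions, not part of this PR):

import logging


def test_del_filter_logs(mock_stream, caplog):
    """Check that del_filter() logs the removed filter indices."""
    # mock_stream: hypothetical fixture yielding a connected Stream object.
    caplog.set_level(logging.INFO)
    mock_stream.filter(1.0, 40.0, picks="eeg")
    caplog.clear()
    mock_stream.del_filter(0)
    # matches the "Removing filters at index(es): %s" log added in _base.py
    assert "Removing filters at index(es): 0" in caplog.text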
180 changes: 164 additions & 16 deletions mne_lsl/stream/_base.py
@@ -5,11 +5,12 @@
from math import ceil
from threading import Timer
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
from mne import pick_info, pick_types
from mne.channels import rename_channels
from mne.utils import check_version, use_log_level
from mne.utils import check_version

if check_version("mne", "1.6"):
from mne._fiff.constants import FIFF, _ch_unit_mul_named
@@ -25,14 +26,15 @@
from mne.io.pick import _picks_to_idx
from mne.channels.channels import SetChannelsMixin

from ..utils._checks import check_type, check_value
from ..utils._checks import check_type, check_value, ensure_int
from ..utils._docs import copy_doc, fill_doc
from ..utils.logs import logger, verbose
from ..utils.meas_info import _HUMAN_UNITS, _set_channel_units
from ._filters import StreamFilter, create_filter

if TYPE_CHECKING:
from datetime import datetime
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union

from mne import Info
from mne.channels import DigMontage
@@ -306,6 +308,83 @@ def disconnect(self) -> BaseStream:
# This method needs to close any inlet/network object and needs to end with
# self._reset_variables().

def del_filter(self, idx: Union[int, list[int], tuple[int], str] = "all") -> None:
"""Remove a filter from the list of applied filters.

Parameters
----------
idx : ``'all'`` | int | list of int | tuple of int
If the string ``'all'`` (default), remove all filters. If an integer or a
list of integers, remove the filter(s) at the given index(es) from
``Stream.filters``.

Notes
-----
When removing a filter, the initial conditions of all the filters applied on
overlapping channels are reset. The initial conditions will be re-estimated as
a step response steady-state.
"""
self._check_connected_and_regular_sampling("del_filter()")
if len(self._filters) == 0:
raise RuntimeError("No filter to remove.")
# validate input
check_type(idx, ("int-like", tuple, list, str), "idx")
if isinstance(idx, str) and idx != "all":
raise ValueError(
"If 'idx' is provided as str, it must be 'all', which will remove all "
"applied filters. Provided '{idx}' is invalid."
)
elif idx == "all":
idx = np.arange(len(self._filters), dtype=np.uint8)
elif isinstance(idx, (tuple, list)):
for elt in idx:
check_type(elt, ("int-like",), "idx")
idx = np.array(idx, dtype=np.uint8)
else:
# ensure_int is run as a sanity-check, it should not be possible to enter
# this statement without idx as int-like.
idx = np.array([ensure_int(idx, "idx")], dtype=np.uint8)
if not all(0 <= k < len(self._filters) for k in idx):
raise ValueError(
"The index 'idx' must be a positive integer or a list of positive "
"integers not exceeding the number of filters minus 1: "
f"{len(self._filters) - 1}."
)
idx_unique = np.unique(idx)
if idx_unique.size != idx.size:
warn(
"The index 'idx' contains duplicates. Only unique indices will be "
"used.",
RuntimeWarning,
stacklevel=2,
)
idx = np.sort(idx_unique)
logger.info(
"Removing filters at index(es): %s\n%s",
", ".join([str(k) for k in idx]),
"\n".join([repr(self._filters[k]) for k in idx]),
)
# figure out which filters have overlapping channels and will need their initial
# conditions to be reset to a step response steady-state.
picks = np.unique(np.hstack([self._filters[k]["picks"] for k in idx]))
filters2reset = list()
for k, filt in enumerate(self._filters):
if k in idx:
continue # this filter will be deleted
if np.intersect1d(filt["picks"], picks).size != 0:
filters2reset.append(k)
if len(filters2reset) != 0:
logger.info(
"The initial conditions will be reset on filters:\n%s",
"\n".join([repr(self._filters[k]) for k in filters2reset]),
)
# interrupt acquisition and apply changes
with self._interrupt_acquisition():
for k in filters2reset:
self._filters[k]["zi"] = None
for k in idx[::-1]:
del self._filters[k]

def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStream:
"""Drop channel(s).

@@ -338,16 +417,76 @@ def drop_channels(self, ch_names: Union[str, list[str], tuple[str]]) -> BaseStream:
self._pick(picks)
return self

def filter(self) -> BaseStream: # noqa: A003
"""Filter the stream. Not implemented.
@verbose
@fill_doc
def filter(
self,
l_freq: Optional[float],
h_freq: Optional[float],
picks,
iir_params: Optional[dict[str, Any]] = None,
*,
verbose: Optional[Union[bool, str, int]] = None,
) -> BaseStream: # noqa: A003
"""Filter the stream with an IIR causal filter.

Once a filter is applied, the buffer is updated in real-time with the filtered
data. It is possible to apply more than one filter.

.. code-block:: python

stream = Stream(2.0).connect()
stream.filter(1.0, 40.0, picks="eeg")
stream.filter(1.0, 15.0, picks="ecg").filter(0.1, 5, picks="EDA")

Parameters
----------
%(l_freq)s
%(h_freq)s
%(picks_all)s
%(iir_params)s
%(verbose)s

Returns
-------
stream : instance of ``Stream``
The stream instance modified in-place.
"""
self._check_connected_and_regular_sampling("filter()")
raise NotImplementedError
# validate the arguments and ensure 'sos' output
picks = _picks_to_idx(self._info, picks, "all", "bads", allow_empty=False)
iir_params = (
dict(order=4, ftype="butter", output="sos")
if iir_params is None
else iir_params
)
check_type(iir_params, (dict,), "iir_params")
if ("output" in iir_params and iir_params["output"] != "sos") or all(
key in iir_params for key in ("a", "b")
):
warn(
"Only 'sos' output is supported for real-time filtering. The filter "
"output will be automatically changed. Please set "
"iir_params=dict(output='sos', ...) in your call to filter().",
RuntimeWarning,
stacklevel=2,
)
for key in ("a", "b"):
if key in iir_params:
del iir_params[key]
iir_params["output"] = "sos"
# construct an IIR filter
filt = create_filter(
sfreq=self._info["sfreq"],
l_freq=l_freq,
h_freq=h_freq,
iir_params=iir_params,
)
filt.update(picks=picks) # channel selection
# add filter to the list of applied filters
with self._interrupt_acquisition():
self._filters.append(StreamFilter(filt))
return self

@copy_doc(ContainsMixin.get_channel_types)
def get_channel_types(
@@ -381,7 +520,7 @@ def get_channel_units(
self._check_connected(name="get_channel_units()")
check_type(only_data_chs, (bool,), "only_data_chs")
none = "data" if only_data_chs else "all"
picks = _picks_to_idx(self._info, picks, none, (), allow_empty=False)
picks = _picks_to_idx(self._info, picks, none, "bads", allow_empty=False)
channel_units = list()
for idx in picks:
channel_units.append(
@@ -442,7 +581,7 @@ def get_data(
# 8.68 µs ± 113 ns per loop
# >>> %timeit _picks_to_idx(raw.info, None)
# 253 µs ± 1.22 µs per loop
picks = _picks_to_idx(self._info, picks, none="all")
picks = _picks_to_idx(self._info, picks, none="all", exclude="bads")
self._n_new_samples = 0 # reset the number of new samples
return self._buffer[-n_samples:, picks].T, self._timestamps[-n_samples:]
except Exception:
@@ -475,7 +614,7 @@ def pick(self, picks, exclude=()) -> BaseStream:

Parameters
----------
%(picks_all)s
%(picks_base)s all channels.
exclude : str | list of str
Set of channels to exclude, only used when picking is based on types, e.g.
``exclude='bads'`` when ``picks="meg"``.
@@ -837,20 +976,21 @@ def _pick(self, picks: NDArray[+ScalarIntType]) -> None:
"The channel selection must be done before adding a re-refenrecing "
"schema with Stream.set_eeg_reference()."
)

picks_inlet = picks[np.where(picks < self._picks_inlet.size)[0]]
if picks_inlet.size == 0:
raise RuntimeError(
"The requested channel selection would not leave any channel from the "
"LSL Stream."
"Stream."
)
if len(self._filters) != 0:
raise RuntimeError(
"The channel selection must be done before adding filters to the "
"Stream."
)

with self._interrupt_acquisition():
with use_log_level(logger.level):
self._info = pick_info(self._info, picks)
self._info = pick_info(self._info, picks, verbose=logger.level)
self._picks_inlet = self._picks_inlet[picks_inlet]
self._buffer = self._buffer[:, picks]

# prune added channels which are not part of the inlet
for ch in self._added_channels[::-1]:
if ch not in self.ch_names:
Expand All @@ -869,6 +1009,7 @@ def _reset_variables(self) -> None:
self._added_channels = []
self._ref_channels = None
self._ref_from = None
self._filters = []
self._timestamps = None
# This method needs to reset any stream-system-specific variables, e.g. an inlet
# or a StreamInfo for LSL streams.
@@ -883,7 +1024,6 @@ def compensation_grade(self) -> Optional[int]:
self._check_connected(name="compensation_grade")
return super().compensation_grade

# ----------------------------------------------------------------------------------
@property
def ch_names(self) -> list[str]:
"""Name of the channels.
Expand Down Expand Up @@ -919,6 +1059,14 @@ def dtype(self) -> Optional[DTypeLike]:
"""Channel format of the stream."""
return getattr(self._buffer, "dtype", None)

@property
def filters(self) -> list[StreamFilter]:
"""List of filters applied to the real-time Stream.

:type: :class:`list` of ``StreamFilter``
"""
return self._filters

@property
def info(self) -> Info:
"""Info of the LSL stream.
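
For context, the sketch below illustrates the per-buffer filtering scheme the new code relies on: an IIR filter expressed as second-order sections ("sos", the only output filter() accepts) is applied chunk by chunk while the filter state zi is carried between calls, and a reset state is re-estimated from the step-response steady state (scipy.signal.sosfilt_zi), as described in the del_filter() notes. This is a self-contained illustration using scipy.signal.butter instead of the PR's create_filter/StreamFilter machinery; shapes and values are illustrative.

import numpy as np
from scipy.signal import butter, sosfilt, sosfilt_zi

sfreq = 1000.0  # Hz
n_channels = 4

# 1-40 Hz band-pass as second-order sections (stand-in for create_filter).
sos = butter(4, [1.0, 40.0], btype="bandpass", output="sos", fs=sfreq)

rng = np.random.default_rng(42)
zi = None  # filter state, shape (n_sections, n_channels, 2) once initialized

for _ in range(5):  # five incoming buffers of 100 samples each
    chunk = rng.normal(size=(n_channels, 100))
    if zi is None:
        # step-response steady state, scaled per channel by the first sample
        zi = sosfilt_zi(sos)[:, np.newaxis, :] * chunk[:, [0]]
    # carry the state across buffers so consecutive chunks stay continuous
    filtered, zi = sosfilt(sos, chunk, axis=-1, zi=zi)

Because only the state zi (not past data) persists between buffers, removing a filter only requires resetting zi on filters with overlapping channels, which is exactly what del_filter() does before acquisition resumes.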