From ebb9fc92ef8a6d268b84c5f98a7b9305cbb40868 Mon Sep 17 00:00:00 2001
From: Mathieu Scheltienne
Date: Wed, 7 Aug 2024 12:59:44 +0200
Subject: [PATCH] Add real-time decoding example (#312)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 doc/links.inc                 |   6 ++
 examples/40_decode.py         | 168 ++++++++++++++++++++++++++++++++++
 mne_lsl/stream/epochs.py      |  26 +++++-
 pyproject.toml                |   1 +
 tutorials/10_low_level_API.py |   2 -
 tutorials/40_epochs.py        |  14 +++
 6 files changed, 212 insertions(+), 5 deletions(-)
 create mode 100644 examples/40_decode.py

diff --git a/doc/links.inc b/doc/links.inc
index 4f6acfe3a..c10822eeb 100644
--- a/doc/links.inc
+++ b/doc/links.inc
@@ -17,6 +17,7 @@
 
 .. _lsl: https://labstreaminglayer.org
 .. _lsl intro: https://labstreaminglayer.readthedocs.io/info/intro.html
+.. _lsl language bindings: https://github.com/sccn/labstreaminglayer/tree/master/LSL
 .. _lsl lib: https://github.com/sccn/liblsl
 .. _lsl lib release: https://github.com/sccn/liblsl/releases
 .. _lsl python: https://github.com/labstreaminglayer/pylsl
@@ -28,6 +29,11 @@
 .. _mne installers: https://mne.tools/stable/install/installers.html
 
 
+.. software
+
+.. _sklearn stable: https://scikit-learn.org/stable/
+
+
 .. project
 
 .. _project pypi: https://pypi.org/project/mne-lsl/
diff --git a/examples/40_decode.py b/examples/40_decode.py
new file mode 100644
index 000000000..9e39540db
--- /dev/null
+++ b/examples/40_decode.py
@@ -0,0 +1,168 @@
+"""
+Decoding real-time data
+=======================
+
+.. include:: ./../../links.inc
+
+This example demonstrates how to decode real-time data using `MNE-Python `_
+and `Scikit-learn `_. We will stream the ``sample_audvis_raw.fif``
+file from MNE's sample dataset with a :class:`~mne_lsl.player.PlayerLSL`, process the
+signal through a :class:`~mne_lsl.stream.StreamLSL`, and decode the epochs created with
+:class:`~mne_lsl.stream.EpochsStream`.
+"""
+
+import time
+import uuid
+
+import numpy as np
+from matplotlib import pyplot as plt
+from mne.decoding import Vectorizer
+from mne.io import read_raw_fif
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import ShuffleSplit, cross_val_score
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import StandardScaler
+
+from mne_lsl.datasets import sample
+from mne_lsl.player import PlayerLSL
+from mne_lsl.stream import EpochsStream, StreamLSL
+
+fname = sample.data_path() / "mne-sample" / "sample_audvis_raw.fif"
+raw = read_raw_fif(fname, preload=False).pick(("meg", "stim")).load_data()
+source_id = uuid.uuid4().hex
+player = PlayerLSL(raw, chunk_size=200, name="real-time-decoding", source_id=source_id)
+player.start()
+player.info
+
+# %%
+# Signal processing
+# -----------------
+#
+# We will apply minimal signal processing to the data. First, only the gradiometers
+# will be used for decoding, so the other channels are removed. Then we mark bad
+# channels and apply a low-pass filter at 40 Hz.
+
+stream = StreamLSL(bufsize=5, name="real-time-decoding", source_id=source_id)
+stream.connect(acquisition_delay=0.1, processing_flags="all")
+stream.info["bads"] = ["MEG 2443"]
+stream.pick(("grad", "stim")).filter(None, 40, picks="grad")
+stream.info
+
+# %%
+# Epoch the signal
+# ----------------
+#
+# Next, we will create epochs around the events ``1`` (audio left) and ``3`` (visual
+# left).
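+#
+# .. note::
+#
+#     As a quick sanity check (an illustrative aside, not part of the original
+#     example), the event codes available on the trigger channel can be inspected
+#     offline with :func:`mne.find_events` on the raw recording loaded above:
+#
+#     .. code-block:: python
+#
+#         from mne import find_events
+#
+#         events = find_events(raw, stim_channel="STI 014")
+#         print(sorted(set(events[:, 2])))  # integer event codes in the recording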
+
+epochs = EpochsStream(
+    stream,
+    bufsize=10,
+    event_id=dict(audio_left=1, visual_left=3),
+    event_channels="STI 014",
+    tmin=-0.2,
+    tmax=0.5,
+    baseline=(None, 0),
+    reject=dict(grad=4000e-13),  # unit: T / m (gradiometers)
+).connect(acquisition_delay=0.1)
+epochs.info
+
+# %%
+# Define the classifier
+# ---------------------
+#
+# We will use a :class:`~sklearn.linear_model.LogisticRegression` classifier to decode
+# the epochs.
+#
+# .. note::
+#
+#     The object :class:`~mne.decoding.Vectorizer` is used to transform the epochs
+#     into a 2D array of shape (n_epochs, n_features). It simply reshapes the epochs
+#     data with:
+#
+#     .. code-block:: python
+#
+#         data = epochs.get_data()
+#         data = data.reshape(data.shape[0], -1)
+
+vectorizer = Vectorizer()
+scaler = StandardScaler()
+clf = LogisticRegression()
+classifier = Pipeline([("vector", vectorizer), ("scaler", scaler), ("logreg", clf)])
+
+# %%
+# Decode
+# ------
+#
+# First, we will wait for a minimum number of epochs to be available. Then, the
+# classifier will be trained a first time, and it will be re-trained every 5 new
+# epochs.
+
+min_epochs = 10
+while epochs.n_new_epochs < min_epochs:
+    time.sleep(0.5)
+
+# prepare the figure to plot the classification score
+if not plt.isinteractive():
+    plt.ion()
+fig, ax = plt.subplots()
+ax.set_xlabel("Epochs n°")
+ax.set_ylabel("Classification score (% correct)")
+ax.set_title("Real-time decoding")
+ax.set_xlim([min_epochs, 50])
+ax.set_ylim([30, 105])
+ax.axhline(50, color="k", linestyle="--", label="Chance level")
+plt.show()
+
+# decoding loop
+scores_x, scores, std_scores = [], [], []
+while True:
+    if len(scores_x) != 0 and 50 <= scores_x[-1]:
+        break
+    n_epochs = epochs.n_new_epochs
+    if n_epochs == 0 or n_epochs % 5 != 0:
+        time.sleep(0.5)  # give time to the streaming and acquisition threads
+        continue
+
+    if len(scores_x) == 0:  # first training
+        X = epochs.get_data(n_epochs=n_epochs)
+        y = epochs.events[-n_epochs:]
+    else:
+        X = np.concatenate((X, epochs.get_data(n_epochs=n_epochs)), axis=0)
+        y = np.concatenate((y, epochs.events[-n_epochs:]))
+    cv = ShuffleSplit(5, test_size=0.2, random_state=42)
+    scores_t = cross_val_score(classifier, X, y, cv=cv, n_jobs=1) * 100
+    std_scores.append(scores_t.std())
+    scores.append(scores_t.mean())
+    scores_x.append(scores_x[-1] + n_epochs if len(scores_x) != 0 else n_epochs)
+
+    # update the figure
+    ax.plot(scores_x[-2:], scores[-2:], "-x", color="b")
+    hyp_limits = (
+        np.asarray(scores[-2:]) - np.asarray(std_scores[-2:]),
+        np.asarray(scores[-2:]) + np.asarray(std_scores[-2:]),
+    )
+    ax.fill_between(
+        scores_x[-2:], y1=hyp_limits[0], y2=hyp_limits[1], color="b", alpha=0.5
+    )
+    plt.pause(0.1)
+    plt.draw()
+
+# %%
+# Free resources
+# --------------
+#
+# When you are done with a :class:`~mne_lsl.player.PlayerLSL`,
+# :class:`~mne_lsl.stream.StreamLSL`, or :class:`~mne_lsl.stream.EpochsStream`, don't
+# forget to free the resources they use to continuously mock an LSL stream or receive
+# new data from an LSL stream.
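+#
+# .. note::
+#
+#     In a standalone script, a simple way to guarantee this cleanup (a sketch, not
+#     part of the original example) is to wrap the decoding loop in a
+#     ``try`` / ``finally`` block:
+#
+#     .. code-block:: python
+#
+#         try:
+#             ...  # acquisition and decoding loop
+#         finally:
+#             epochs.disconnect()
+#             stream.disconnect()
+#             player.stop()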
+
+epochs.disconnect()
+
+# %%
+
+stream.disconnect()
+
+# %%
+
+player.stop()
diff --git a/mne_lsl/stream/epochs.py b/mne_lsl/stream/epochs.py
index d9347bcaf..d104a2ad6 100644
--- a/mne_lsl/stream/epochs.py
+++ b/mne_lsl/stream/epochs.py
@@ -324,6 +324,7 @@ def connect(self, acquisition_delay: float = 0.001) -> EpochsStream:
             ),
             dtype=self._stream._buffer.dtype,
         )
+        self._buffer_events = np.zeros(self._bufsize, dtype=np.int16)
         self._executor = (
             ThreadPoolExecutor(max_workers=1) if self._acquisition_delay != 0 else None
         )
@@ -504,12 +505,16 @@ def _acquire(self) -> None:
             # select data, for loop is faster than the fancy indexing ideas tried and
             # will anyway operate on a small number of events most of the time.
             data_selection = np.empty(
-                (events.shape[0], self._buffer.shape[1], self._picks.size),
+                (
+                    min(events.shape[0], self._bufsize),
+                    self._buffer.shape[1],
+                    self._picks.size,
+                ),
                 dtype=data.dtype,
             )
-            for k, start in enumerate(events[:, 0]):
+            for k, start in enumerate(events[:, 0][::-1]):
                 start += self._tmin_shift
-                data_selection[k] = data[
+                data_selection[-(k + 1)] = data[
                     self._picks, start : start + self._buffer.shape[1]
                 ].T
             # apply processing
@@ -530,6 +535,8 @@ def _acquire(self) -> None:
             # roll buffer and add new epochs
             self._buffer = np.roll(self._buffer, -events.shape[0], axis=0)
             self._buffer[-events.shape[0] :, :, :] = data_selection
+            self._buffer_events = np.roll(self._buffer_events, -events.shape[0])
+            self._buffer_events[-events.shape[0] :] = events[:, 2]
             # update the last ts and the number of new epochs
             self._n_new_epochs += events.shape[0]
         except Exception as error:  # pragma: no cover
@@ -553,6 +560,7 @@ def _reset_variables(self):
         """Reset variables defined after connection."""
         self._acquisition_delay = None
         self._buffer = None
+        self._buffer_events = None
         self._ch_idx_by_type = None
         self._executor = None
         self._info = None
@@ -592,6 +600,18 @@ def connected(self) -> bool:
         assert not any(getattr(self, attr, None) is None for attr in attributes)
         return True
 
+    @property
+    def events(self) -> NDArray[np.int16]:
+        """Events of the epoched LSL stream.
+
+        Contrary to the events stored in ``mne.Epochs.events``, only the integer code
+        of the event is stored in a :class:`~mne_lsl.stream.EpochsStream` object.
+
+        :type: :class:`numpy.ndarray`
+        """
+        self._check_connected("events")
+        return self._buffer_events
+
     @property
     def info(self) -> Info:
         """Info of the epoched LSL stream.
diff --git a/pyproject.toml b/pyproject.toml
index 39d4d3c49..cd9c90657 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,7 @@ doc = [
     'memory-profiler',
     'numpydoc',
     'pyqt5',
+    'scikit-learn',
     'sphinx!=7.2.*',
     'sphinx-copybutton',
     'sphinx-design',
diff --git a/tutorials/10_low_level_API.py b/tutorials/10_low_level_API.py
index 1b0db0f9c..aaf7a67fa 100644
--- a/tutorials/10_low_level_API.py
+++ b/tutorials/10_low_level_API.py
@@ -6,8 +6,6 @@
 
 .. include:: ./../../links.inc
 
-.. _lsl language bindings: https://github.com/sccn/labstreaminglayer/tree/master/LSL
-
 LSL is a library designed for streaming time series data across different platforms and
 programming languages. The `core library `_ is primarily written in C++, and bindings
 are accessible for Python, C#, Java, MATLAB, and Unity, among others. You can
diff --git a/tutorials/40_epochs.py b/tutorials/40_epochs.py
index bee31da9d..48f1e70ba 100644
--- a/tutorials/40_epochs.py
+++ b/tutorials/40_epochs.py
@@ -262,6 +262,20 @@
 plt.show()
 
 # %%
+# Finally, in this case a single event was kept in the
+# :class:`~mne_lsl.stream.EpochsStream`, but if more events are retained, it is
+# important to know which one is which. This information is stored in the property
+# :attr:`~mne_lsl.stream.EpochsStream.events` of the
+# :class:`~mne_lsl.stream.EpochsStream`, which is an internal buffer of the event codes.
+
+epochs.events
+
+# %%
+# .. note::
+#
+#     In the case of an irregularly sampled event stream, the event code represents the
+#     channel index within the event stream.
+#
 # Free resources
 # --------------
 #
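
.. note::

    A minimal usage sketch (not part of the patch) of the new
    :attr:`~mne_lsl.stream.EpochsStream.events` property, assuming a connected
    :class:`~mne_lsl.stream.EpochsStream` named ``epochs`` as in the example above:

    .. code-block:: python

        n = epochs.n_new_epochs
        data = epochs.get_data(n_epochs=n)  # data of the last n received epochs
        codes = epochs.events[-n:]  # event code of each of those epochs, same order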