Skip to content

Commit

Permalink
Add real-time decoding example (#312)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mscheltienne and pre-commit-ci[bot] authored Aug 7, 2024
1 parent ede2439 commit ebb9fc9
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 5 deletions.
6 changes: 6 additions & 0 deletions doc/links.inc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
.. _lsl: https://labstreaminglayer.org
.. _lsl intro: https://labstreaminglayer.readthedocs.io/info/intro.html
.. _lsl language bindings: https://github.com/sccn/labstreaminglayer/tree/master/LSL
.. _lsl lib: https://github.com/sccn/liblsl
.. _lsl lib release: https://github.com/sccn/liblsl/releases
.. _lsl python: https://github.com/labstreaminglayer/pylsl
Expand All @@ -28,6 +29,11 @@
.. _mne installers: https://mne.tools/stable/install/installers.html


.. software
.. _sklearn stable: https://scikit-learn.org/stable/


.. project
.. _project pypi: https://pypi.org/project/mne-lsl/
Expand Down
168 changes: 168 additions & 0 deletions examples/40_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""
Decoding real-time data
=======================
.. include:: ./../../links.inc
This example demonstrates how to decode real-time data using `MNE-Python <mne stable_>`_
and `Scikit-learn <sklearn stable_>`_. We will stream the ``sample_audvis_raw.fif``
file from MNE's sample dataset with a :class:`~mne_lsl.player.PlayerLSL`, process the
signal through a :class:`~mne_lsl.stream.StreamLSL`, and decode the epochs created with
:class:`~mne_lsl.stream.EpochsStream`.
"""

import time
import uuid

import numpy as np
from matplotlib import pyplot as plt
from mne.decoding import Vectorizer
from mne.io import read_raw_fif
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.preprocessing import StandardScaler

from mne_lsl.datasets import sample
from mne_lsl.player import PlayerLSL
from mne_lsl.stream import EpochsStream, StreamLSL

fname = sample.data_path() / "mne-sample" / "sample_audvis_raw.fif"
raw = read_raw_fif(fname, preload=False).pick(("meg", "stim")).load_data()
source_id = uuid.uuid4().hex
player = PlayerLSL(raw, chunk_size=200, name="real-time-decoding", source_id=source_id)
player.start()
player.info

# %%
# Signal processing
# -----------------
#
# We will apply minimal signal processing to the data. First, only the gradiometers will
# be used for decoding, thus other channels are removed. Then we mark bad channels and
# applying a low-pass filter at 40 Hz.

stream = StreamLSL(bufsize=5, name="real-time-decoding", source_id=source_id)
stream.connect(acquisition_delay=0.1, processing_flags="all")
stream.info["bads"] = ["MEG 2443"]
stream.pick(("grad", "stim")).filter(None, 40, picks="grad")
stream.info

# %%
# Epoch the signal
# ----------------
#
# Next, we will create epochs around the event ``1`` (audio left) and ``3`` (visual
# left).

epochs = EpochsStream(
stream,
bufsize=10,
event_id=dict(audio_left=1, visual_left=3),
event_channels="STI 014",
tmin=-0.2,
tmax=0.5,
baseline=(None, 0),
reject=dict(grad=4000e-13), # unit: T / m (gradiometers)
).connect(acquisition_delay=0.1)
epochs.info

# %%
# Define the classifier
# ---------------------
#
# We will use a :class:`~sklearn.linear_model.LogisticRegression` classifier to decode
# the epochs.
#
# .. note::
#
# The object :class:`~mne.decoding.Vectorizer` is used to transform the epochs in a
# 2D array of shape (n_epochs, n_features). It's simply reshapes the epochs data
# with:
#
# .. code-block:: python
#
# data = epochs.get_data()
# data = data.reshape(data.shape[0], -1)

vectorizer = Vectorizer()
scaler = StandardScaler()
clf = LogisticRegression()
classifier = Pipeline([("vector", vectorizer), ("scaler", scaler), ("svm", clf)])

# %%
# Decode
# ------
#
# First, we will wait for a minimum number of epochs to be available. Then, the
# classifier will be trained for the first time and future epochs will be used to
# retrain the classifier every 5 epochs.

min_epochs = 10
while epochs.n_new_epochs < min_epochs:
time.sleep(0.5)

# prepare figure to plot classifiation score
if not plt.isinteractive():
plt.ion()
fig, ax = plt.subplots()
ax.set_xlabel("Epochsn n°")
ax.set_ylabel("Classification score (% correct)")
ax.set_title("Real-time decoding")
ax.set_xlim([min_epochs, 50])
ax.set_ylim([30, 105])
ax.axhline(50, color="k", linestyle="--", label="Chance level")
plt.show()

# decoding loop
scores_x, scores, std_scores = [], [], []
while True:
if len(scores_x) != 0 and 50 <= scores_x[-1]:
break
n_epochs = epochs.n_new_epochs
if n_epochs == 0 or n_epochs % 5 != 0:
time.sleep(0.5) # give time to the streaming and acquisition threads
continue

if len(scores_x) == 0: # first training
X = epochs.get_data(n_epochs=n_epochs)
y = epochs.events[-n_epochs:]
else:
X = np.concatenate((X, epochs.get_data(n_epochs=n_epochs)), axis=0)
y = np.concatenate((y, epochs.events[-n_epochs:]))
cv = ShuffleSplit(5, test_size=0.2, random_state=42)
scores_t = cross_val_score(classifier, X, y, cv=cv, n_jobs=1) * 100
std_scores.append(scores_t.std())
scores.append(scores_t.mean())
scores_x.append(scores_x[-1] + n_epochs if len(scores_x) != 0 else n_epochs)

# update figure
ax.plot(scores_x[-2:], scores[-2:], "-x", color="b")
hyp_limits = (
np.asarray(scores[-2:]) - np.asarray(std_scores[-2:]),
np.asarray(scores[-2:]) + np.asarray(std_scores[-2:]),
)
fill = ax.fill_between(
scores_x[-2:], y1=hyp_limits[0], y2=hyp_limits[1], color="b", alpha=0.5
)
plt.pause(0.1)
plt.draw()

# %%
# Free resources
# --------------
#
# When you are done with a :class:`~mne_lsl.player.PlayerLSL`,
# :class:`~mne_lsl.stream.StreamLSL` ir :class:`~mne_lsl.stream.EpochsStream`, don't
# forget to free the resources they use to continuously mock an LSL stream or receive
# new data from an LSL stream.

epochs.disconnect()

# %%

stream.disconnect()

# %%

player.stop()
26 changes: 23 additions & 3 deletions mne_lsl/stream/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def connect(self, acquisition_delay: float = 0.001) -> EpochsStream:
),
dtype=self._stream._buffer.dtype,
)
self._buffer_events = np.zeros(self._bufsize, dtype=np.int16)
self._executor = (
ThreadPoolExecutor(max_workers=1) if self._acquisition_delay != 0 else None
)
Expand Down Expand Up @@ -504,12 +505,16 @@ def _acquire(self) -> None:
# select data, for loop is faster than the fancy indexing ideas tried and
# will anyway operate on a small number of events most of the time.
data_selection = np.empty(
(events.shape[0], self._buffer.shape[1], self._picks.size),
(
min(events.shape[0], self._bufsize),
self._buffer.shape[1],
self._picks.size,
),
dtype=data.dtype,
)
for k, start in enumerate(events[:, 0]):
for k, start in enumerate(events[:, 0][::-1]):
start += self._tmin_shift
data_selection[k] = data[
data_selection[-(k + 1)] = data[
self._picks, start : start + self._buffer.shape[1]
].T
# apply processing
Expand All @@ -530,6 +535,8 @@ def _acquire(self) -> None:
# roll buffer and add new epochs
self._buffer = np.roll(self._buffer, -events.shape[0], axis=0)
self._buffer[-events.shape[0] :, :, :] = data_selection
self._buffer_events = np.roll(self._buffer_events, -events.shape[0])
self._buffer_events[-events.shape[0] :] = events[:, 2]
# update the last ts and the number of new epochs
self._n_new_epochs += events.shape[0]
except Exception as error: # pragma: no cover
Expand All @@ -553,6 +560,7 @@ def _reset_variables(self):
"""Reset variables defined after connection."""
self._acquisition_delay = None
self._buffer = None
self._buffer_events = None
self._ch_idx_by_type = None
self._executor = None
self._info = None
Expand Down Expand Up @@ -592,6 +600,18 @@ def connected(self) -> bool:
assert not any(getattr(self, attr, None) is None for attr in attributes)
return True

@property
def events(self) -> NDArray[np.int16]:
"""Events of the epoched LSL stream.
Contrary to the events stored in ``mne.Epochs.events``, only the integer code
of the event is stored in a :class:`~mne_lsl.stream.EpochsStream` object.
:type: :class:`numpy.ndarray`
"""
self._check_connected("events")
return self._buffer_events

@property
def info(self) -> Info:
"""Info of the epoched LSL stream.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ doc = [
'memory-profiler',
'numpydoc',
'pyqt5',
'scikit-learn',
'sphinx!=7.2.*',
'sphinx-copybutton',
'sphinx-design',
Expand Down
2 changes: 0 additions & 2 deletions tutorials/10_low_level_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
.. include:: ./../../links.inc
.. _lsl language bindings: https://github.com/sccn/labstreaminglayer/tree/master/LSL
LSL is a library designed for streaming time series data across different platforms and
programming languages. The `core library <lsl lib_>`_ is primarily written in C++, and
bindings are accessible for Python, C#, Java, MATLAB, and Unity, among others. You can
Expand Down
14 changes: 14 additions & 0 deletions tutorials/40_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,20 @@
plt.show()

# %%
# Finally, in this case a single event was kept in the
# :class:`~mne_lsl.stream.EpochsStream`, but if more events are retained, it is
# important to know which one is which. This information is stored in the property
# :attr:`~mne_lsl.stream.EpochsStream.events` of the
# :class:`~mne_lsl.stream.EpochsStream`, which is an internal buffer of the event codes.

epochs.events

# %%
# .. note::
#
# In the case of an irregularly sampled event stream, the event code represents the
# channel idx within the event stream.
#
# Free resources
# --------------
#
Expand Down

0 comments on commit ebb9fc9

Please sign in to comment.