diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index ab0311c0233..1556884a518 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -30,6 +30,7 @@ Enhancements - Added option ``remove_dc`` to to :meth:`Raw.compute_psd() `, :meth:`Epochs.compute_psd() `, and :meth:`Evoked.compute_psd() `, to allow skipping DC removal when computing Welch or multitaper spectra (:gh:`11769` by `Nikolai Chapochnikov`_) - Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_) - Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_) +- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array ` data (:gh:`11803` by `Alex Rockhill`_) Bugs ~~~~ diff --git a/doc/time_frequency.rst b/doc/time_frequency.rst index a366dbdecb9..63601691414 100644 --- a/doc/time_frequency.rst +++ b/doc/time_frequency.rst @@ -17,7 +17,9 @@ Time-Frequency EpochsTFR CrossSpectralDensity Spectrum + SpectrumArray EpochsSpectrum + EpochsSpectrumArray Functions that operate on mne-python objects: diff --git a/mne/conftest.py b/mne/conftest.py index f6f985c41da..1cfe6f021d6 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -297,6 +297,12 @@ def raw_ctf(): return raw_ctf +@pytest.fixture(scope="function") +def raw_spectrum(raw): + """Get raw with power spectral density computed from mne.io.tests.data.""" + return raw.compute_psd() + + @pytest.fixture(scope="function") def events(): """Get events from mne.io.tests.data.""" @@ -349,6 +355,12 @@ def epochs_full(): return _get_epochs(None).load_data() +@pytest.fixture() +def epochs_spectrum(): + """Get epochs with power spectral density computed from mne.io.tests.data.""" + return _get_epochs().load_data().compute_psd() + + @pytest.fixture() def epochs_empty(): """Get empty epochs from mne.io.tests.data.""" diff --git a/mne/epochs.py b/mne/epochs.py index 0182488a876..5e5d93b78e4 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3197,40 +3197,17 @@ class EpochsArray(BaseEpochs): measure. %(info_not_none)s Consider using :func:`mne.create_info` to populate this structure. - events : None | array of int, shape (n_events, 3) - The events typically returned by the read_events function. - If some events don't match the events of interest as specified - by event_id, they will be marked as 'IGNORED' in the drop log. - If None (default), all event values are set to 1 and event time-samples - are set to range(n_epochs). - tmin : float - Start time before event. If nothing provided, defaults to 0. - event_id : int | list of int | dict | None - The id of the event to consider. If dict, - the keys can later be used to access associated events. Example: - dict(auditory=1, visual=3). If int, a dict will be created with - the id as string. If a list, all events with the IDs specified - in the list are used. If None, all events will be used with - and a dict is created with string integer names corresponding - to the event id integers. + %(events_epochs)s + %(tmin_epochs)s + %(event_id)s %(reject_epochs)s %(flat)s - reject_tmin : scalar | None - Start of the time window used to reject epochs (with the default None, - the window will start with tmin). - reject_tmax : scalar | None - End of the time window used to reject epochs (with the default None, - the window will end with tmax). + %(epochs_reject_tmin_tmax)s %(baseline_epochs)s Defaults to ``None``, i.e. no baseline correction. - proj : bool | 'delayed' - Apply SSP projection vectors. See :class:`mne.Epochs` for details. - on_missing : str - See :class:`mne.Epochs` docstring for details. - metadata : instance of pandas.DataFrame | None - See :class:`mne.Epochs` docstring for details. - - .. versionadded:: 0.16 + %(proj_epochs)s + %(on_missing_epochs)s + %(metadata_epochs)s %(selection)s %(drop_log)s diff --git a/mne/time_frequency/__init__.py b/mne/time_frequency/__init__.py index b4c586b83f7..8f245bee7f6 100644 --- a/mne/time_frequency/__init__.py +++ b/mne/time_frequency/__init__.py @@ -34,7 +34,13 @@ "tfr_array_multitaper", ], "psd": ["psd_array_welch"], - "spectrum": ["EpochsSpectrum", "Spectrum", "read_spectrum"], + "spectrum": [ + "EpochsSpectrum", + "EpochsSpectrumArray", + "Spectrum", + "SpectrumArray", + "read_spectrum", + ], "tfr": [ "_BaseTFR", "AverageTFR", diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 299b4aee984..a67a3dbc3c4 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -397,8 +397,10 @@ def __setstate__(self, state): self._data_type = state["data_type"] self.preload = True # instance type - inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked) + inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) self._inst_type = inst_types[state["inst_type_str"]] + if "weights" in state and state["weights"] is not None: + self._mt_weights = state["weights"] def __repr__(self): """Build string representation of the Spectrum object.""" @@ -486,6 +488,8 @@ def _get_instance_type_string(self): inst_type_str = "Epochs" elif self._inst_type in (Evoked, EvokedArray): inst_type_str = "Evoked" + elif self._inst_type is np.ndarray: + inst_type_str = "Array" else: raise RuntimeError(f"Unknown instance type {self._inst_type} in Spectrum") return inst_type_str @@ -766,6 +770,8 @@ def plot_topo( layout = find_layout(self.info) psds, freqs = self.get_data(return_freqs=True) + if "epoch" in self._dims: + psds = np.mean(psds, axis=self._dims.index("epoch")) if dB: psds = 10 * np.log10(psds) y_label = "dB" @@ -977,7 +983,7 @@ def to_data_frame( # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # triage for Epoch-derived or unaggregated spectra - from_epo = self._get_instance_type_string() == "Epochs" + from_epo = self._dims[0] == "epoch" unagg_welch = "segment" in self._dims unagg_mt = "taper" in self._dims # arg checking @@ -1089,6 +1095,7 @@ class Spectrum(BaseSpectrum): See Also -------- EpochsSpectrum + SpectrumArray mne.io.Raw.compute_psd mne.Epochs.compute_psd mne.Evoked.compute_psd @@ -1190,6 +1197,75 @@ def __getitem__(self, item): return BaseRaw._getitem(self, item, return_times=False) +def _check_data_shape(data, freqs, info, ndim): + if data.ndim != ndim: + raise ValueError(f"Data must be a {ndim}D array.") + want_n_chan = _pick_data_channels(info).size + want_n_freq = freqs.size + got_n_chan, got_n_freq = data.shape[-2:] + if got_n_chan != want_n_chan: + raise ValueError( + f"The number of channels in `data` ({got_n_chan}) must match the " + f"number of good data channels in `info` ({want_n_chan})." + ) + if got_n_freq != want_n_freq: + raise ValueError( + f"The last dimension of `data` ({got_n_freq}) must have the same " + f"number of elements as `freqs` ({want_n_freq})." + ) + + +@fill_doc +class SpectrumArray(Spectrum): + """Data object for precomputed spectral data (in NumPy array format). + + Parameters + ---------- + data : array, shape (n_channels, n_freqs) + The power spectral density for each channel. + %(info_not_none)s + %(freqs_tfr)s + %(verbose)s + + See Also + -------- + mne.create_info + mne.EvokedArray + mne.io.RawArray + EpochsSpectrumArray + + Notes + ----- + %(notes_spectrum_array)s + + .. versionadded:: 1.6 + """ + + @verbose + def __init__( + self, + data, + info, + freqs, + *, + verbose=None, + ): + _check_data_shape(data, freqs, info, ndim=2) + + self.__setstate__( + dict( + method="unknown", + data=data, + sfreq=info["sfreq"], + dims=("channel", "freq"), + freqs=freqs, + inst_type_str="Array", + data_type="Power Spectrum", + info=info, + ) + ) + + @fill_doc class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): """Data object for spectral representations of epoched data. @@ -1225,10 +1301,9 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): See Also -------- + EpochsSpectrumArray Spectrum - mne.io.Raw.compute_psd mne.Epochs.compute_psd - mne.Evoked.compute_psd References ---------- @@ -1385,6 +1460,70 @@ def average(self, method="mean"): return Spectrum(state, **defaults) +@fill_doc +class EpochsSpectrumArray(EpochsSpectrum): + """Data object for precomputed epoched spectral data (in NumPy array format). + + Parameters + ---------- + data : array, shape (n_epochs, n_channels, n_freqs) + The power spectral density for each channel in each epoch. + %(info_not_none)s + %(freqs_tfr)s + %(events_epochs)s + %(event_id)s + %(verbose)s + + See Also + -------- + mne.create_info + mne.EpochsArray + SpectrumArray + + Notes + ----- + %(notes_spectrum_array)s + + .. versionadded:: 1.6 + """ + + @verbose + def __init__( + self, + data, + info, + freqs, + events=None, + event_id=None, + *, + verbose=None, + ): + _check_data_shape(data, freqs, info, ndim=3) + if events is not None and data.shape[0] != events.shape[0]: + raise ValueError( + f"The first dimension of `data` ({data.shape[0]}) must match the " + f"first dimension of `events` ({events.shape[0]})." + ) + + self.__setstate__( + dict( + method="unknown", + data=data, + sfreq=info["sfreq"], + dims=("epoch", "channel", "freq"), + freqs=freqs, + inst_type_str="Array", + data_type="Power Spectrum", + info=info, + events=events, + event_id=event_id, + metadata=None, + selection=np.arange(data.shape[0]), + drop_log=tuple(tuple() for _ in range(data.shape[0])), + ) + ) + + def read_spectrum(fname): """Load a :class:`mne.time_frequency.Spectrum` object from disk. diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 45f208395c3..fcbb561faed 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -4,12 +4,14 @@ import numpy as np import pytest from numpy.testing import assert_array_equal, assert_allclose +import matplotlib.pyplot as plt from mne import create_info, make_fixed_length_epochs from mne.io import RawArray from mne import Annotations from mne.time_frequency import read_spectrum from mne.time_frequency.multitaper import _psd_from_mt +from mne.time_frequency.spectrum import SpectrumArray, EpochsSpectrumArray def test_spectrum_errors(raw): @@ -140,18 +142,21 @@ def test_spectrum_io(inst, tmp_path, request, evoked): assert orig == loaded -def test_spectrum_copy(raw): +def test_spectrum_copy(raw_spectrum): """Test copying Spectrum objects.""" - spect = raw.compute_psd() - spect_copy = spect.copy() - assert spect == spect_copy - assert id(spect) != id(spect_copy) + spect_copy = raw_spectrum.copy() + assert raw_spectrum == spect_copy + assert id(raw_spectrum) != id(spect_copy) spect_copy._freqs = None - assert spect.freqs is not None + assert raw_spectrum.freqs is not None def test_spectrum_reject_by_annot(raw): - """Test rejecting by annotation.""" + """Test rejecting by annotation. + + Cannot use raw_spectrum fixture here because we're testing reject_by_annotation in + .compute_psd() method. + """ spect_no_annot = raw.compute_psd() raw.set_annotations(Annotations([1, 5], [3, 3], ["test", "test"])) spect_benign_annot = raw.compute_psd() @@ -164,30 +169,27 @@ def test_spectrum_reject_by_annot(raw): assert spect_no_annot != spect_reject_annot -def test_spectrum_getitem_raw(raw): +def test_spectrum_getitem_raw(raw_spectrum): """Test Spectrum.__getitem__ for Raw-derived spectra.""" - spect = raw.compute_psd() - want = spect.get_data(slice(1, 3), fmax=7) - freq_idx = np.searchsorted(spect.freqs, 7) - got = spect[1:3, :freq_idx] + want = raw_spectrum.get_data(slice(1, 3), fmax=7) + freq_idx = np.searchsorted(raw_spectrum.freqs, 7) + got = raw_spectrum[1:3, :freq_idx] assert_array_equal(want, got) -def test_spectrum_getitem_epochs(epochs): +def test_spectrum_getitem_epochs(epochs_spectrum): """Test Spectrum.__getitem__ for Epochs-derived spectra.""" - spect = epochs.compute_psd() # testing data has just one epoch, its event_id label is "1" - want = spect.get_data() - got = spect["1"].get_data() + want = epochs_spectrum.get_data() + got = epochs_spectrum["1"].get_data() assert_array_equal(want, got) @pytest.mark.parametrize("method", ("mean", partial(np.std, axis=0))) -def test_epochs_spectrum_average(epochs, method): +def test_epochs_spectrum_average(epochs_spectrum, method): """Test EpochsSpectrum.average().""" - spect = epochs.compute_psd() - avg_spect = spect.average(method=method) - assert avg_spect.shape == spect.shape[1:] + avg_spect = epochs_spectrum.average(method=method) + assert avg_spect.shape == epochs_spectrum.shape[1:] assert avg_spect._dims == ("channel", "freq") # no 'epoch' @@ -262,19 +264,20 @@ def _fun(x): assert_frame_equal(agg_df, orig_df, check_categorical=False) -@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) +@pytest.mark.parametrize("inst", ("raw_spectrum", "epochs_spectrum", "evoked")) def test_spectrum_to_data_frame(inst, request, evoked): """Test the to_data_frame method for Spectrum.""" pytest.importorskip("pandas") from pandas.testing import assert_frame_equal # setup - is_epochs = inst == "epochs" + is_already_psd = inst in ("raw_spectrum", "epochs_spectrum") + is_epochs = inst == "epochs_spectrum" inst = _get_inst(inst, request, evoked) extra_dim = () if is_epochs else (1,) extra_cols = ["freq", "condition", "epoch"] if is_epochs else ["freq"] # compute PSD - spectrum = inst.compute_psd() + spectrum = inst if is_already_psd else inst.compute_psd() n_epo, n_chan, n_freq = extra_dim + spectrum.get_data().shape # test wide format df_wide = spectrum.to_data_frame() @@ -343,9 +346,9 @@ def test_spectrum_complex(method, average): assert len(epochs) == 5 assert len(epochs.times) == 2 * sfreq kwargs = dict(output="complex", method=method) + ctx = pytest.warns(UserWarning, match="Zero value") if method == "welch": kwargs["n_fft"] = sfreq - ctx = pytest.warns(UserWarning, match="Zero value") want_dims = ("epoch", "channel", "freq") want_shape = (5, 1, sfreq // 2 + 1) if not average: @@ -386,3 +389,55 @@ def test_spectrum_kwarg_triaging(raw): raw.plot_psd(axes=axes) # `ax` is the correct legacy param name raw.plot_psd(ax=axes) + + +def _check_spectrum_equivalent(spect1, spect2, tmp_path): + data1 = spect1.get_data() + data2 = spect2.get_data() + assert_array_equal(data1, data2) + assert_array_equal(spect1.freqs, spect2.freqs) + + +def test_spectrum_array_errors(epochs_spectrum): + """Test EpochsSpectrumArray constructor errors.""" + data, freqs = epochs_spectrum.get_data(return_freqs=True) + info = epochs_spectrum.info + with pytest.raises(ValueError, match="Data must be a 3D array"): + EpochsSpectrumArray(np.empty((2, 3, 4, 5)), info, freqs) + with pytest.raises(ValueError, match=r"number of channels.*good data channels"): + EpochsSpectrumArray(data[:, :-1], info, freqs) + with pytest.raises(ValueError, match=r"last dimension.*same number of elements"): + EpochsSpectrumArray(data[..., :-1], info, freqs) + # test mismatching events shape + n_epo = data.shape[0] + 1 # +1 so they purposely don't match + events = np.vstack( + (np.arange(n_epo), np.zeros(n_epo, dtype=int), np.ones(n_epo, dtype=int)) + ).T + with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): + EpochsSpectrumArray(data, info, freqs, events) + + +@pytest.mark.parametrize("kind", ("raw", "epochs")) +def test_spectrum_array(kind, tmp_path, request): + """Test EpochsSpectrumArray and SpectrumArray constructors.""" + spectrum = request.getfixturevalue(f"{kind}_spectrum") + data, freqs = spectrum.get_data(return_freqs=True) + Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray + spect_arr = Klass(data=data, info=spectrum.info, freqs=freqs) + _check_spectrum_equivalent(spectrum, spect_arr, tmp_path) + + +@pytest.mark.parametrize("kind", ("raw", "epochs")) +@pytest.mark.parametrize("array", (False, True)) +def test_plot_spectrum(kind, array, request): + """Test plotting (Epochs)Spectrum(Array).""" + spectrum = request.getfixturevalue(f"{kind}_spectrum") + if array: + data, freqs = spectrum.get_data(return_freqs=True) + Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray + spectrum = Klass(data=data, info=spectrum.info, freqs=freqs) + spectrum.plot(average=True, amplitude=True, spatial_colors=True) + spectrum.plot(average=False, amplitude=False, spatial_colors=False) + spectrum.plot_topo() + spectrum.plot_topomap() + plt.close("all") diff --git a/mne/utils/docs.py b/mne/utils/docs.py index f20e785fc34..3aac253d8a0 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2817,6 +2817,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["notes_plot_*_psd_func"] = _notes_plot_psd.format("function") docdict["notes_plot_psd_meth"] = _notes_plot_psd.format("method") +docdict[ + "notes_spectrum_array" +] = """ +It is assumed that the data passed in represent spectral *power* (not amplitude, +phase, model coefficients, etc) and downstream methods (such as +:meth:`~mne.time_frequency.SpectrumArray.plot`) assume power data. If you pass in +something other than power, at the very least axis labels will be inaccurate (and +other things may also not work or be incorrect). +""" + docdict[ "notes_tmax_included_by_default" ] = """ @@ -4568,6 +4578,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Time point of the first sample in data. """ +docdict[ + "tmin_epochs" +] = """ +tmin : float + Start time before event. If nothing provided, defaults to 0. +""" + docdict[ "tmin_raw" ] = """ diff --git a/tutorials/simulation/10_array_objs.py b/tutorials/simulation/10_array_objs.py index 5678a94fee2..a7fc8c88985 100644 --- a/tutorials/simulation/10_array_objs.py +++ b/tutorials/simulation/10_array_objs.py @@ -212,9 +212,8 @@ # decomposition for estimation of power spectra. Or you may wish to # process pre-computed power spectra in MNE. # Following the same logic, it is possible to instantiate averaged power -# spectrum using the :class:`~mne.time_frequency.Spectrum` class. -# This is slightly -# experimental at the moment but works. An API for doing this may follow. +# spectrum using the :class:`~mne.time_frequency.SpectrumArray` or +# :class:`~mne.time_frequency.EpochsSpectrumArray` classes. # compute power spectrum @@ -224,40 +223,11 @@ psd_ave = psd.mean(0) -# map to `~mne.time_frequency.Spectrum` class and explore API - - -def spectrum_from_array( - data: np.ndarray, # spectral features - freqs: np.ndarray, # frequencies - inst_info: mne.Info, # the meta data of MNE instance -) -> mne.time_frequency.Spectrum: # Spectrum object - """Create MNE averaged power spectrum object from custom data""" - state = dict( - method="my_welch", - data=data, - sfreq=inst_info["sfreq"], - dims=("channel", "freq"), - freqs=freqs, - inst_type_str="Raw", - data_type="Averaged Power Spectrum", - info=inst_info, - ) - defaults = dict( - method=None, - fmin=None, - fmax=None, - tmin=None, - tmax=None, - picks=None, - proj=None, - remove_dc=None, - reject_by_annotation=None, - n_jobs=None, - verbose=None, - ) - return mne.time_frequency.Spectrum(state, **defaults) - +info = mne.create_info(["Ch 1", "Ch2"], sfreq=sampling_freq, ch_types="eeg") +spectrum = mne.time_frequency.SpectrumArray( + data=psd_ave, + freqs=freqs, + info=info, +) -spectrum = spectrum_from_array(data=psd_ave, freqs=freqs, inst_info=info) -spectrum.plot(picks=[0, 1], spatial_colors=False, exclude="bads") +spectrum.plot(spatial_colors=False)