Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH, MRG] Add EpochsSpectrumArray and SpectrumArray classes #11803

Merged
merged 62 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
395a758
[ENH, MRG] Add EpochsSpectrumArray and SpectrumArray classes [skip ci
alexrockhill Jul 13, 2023
710cef4
update latest
alexrockhill Jul 13, 2023
e34bc38
fix refs
alexrockhill Jul 13, 2023
3fcb557
wrong versionadded
alexrockhill Jul 13, 2023
597c48c
Update mne/time_frequency/spectrum.py
alexrockhill Jul 14, 2023
6a54f68
Update mne/time_frequency/spectrum.py
alexrockhill Jul 14, 2023
125903e
epoch not epochs
alexrockhill Jul 14, 2023
e0fb07f
edit seealso [ci skip]
drammock Jul 14, 2023
8cee941
Dan review
alexrockhill Jul 14, 2023
738c8d3
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Jul 14, 2023
0cd6ec8
style
alexrockhill Jul 14, 2023
5d18ff3
Merge branch 'main' into defaults
alexrockhill Jul 14, 2023
27be8d8
cruft, test plot
alexrockhill Jul 14, 2023
145b6bb
style
alexrockhill Jul 14, 2023
71016d6
very picky style
alexrockhill Jul 14, 2023
1ccf8a4
add fixtures
alexrockhill Jul 18, 2023
ece2b39
Merge branch 'main' into defaults
alexrockhill Jul 18, 2023
ff203ff
Merge branch 'main' into defaults
alexrockhill Jul 19, 2023
725446d
Dan review
alexrockhill Jul 26, 2023
ecbd0a8
Merge branch 'main' of https://github.com/mne-tools/mne-python into d…
alexrockhill Jul 26, 2023
f8e98ce
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Jul 26, 2023
cb80688
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2023
fd830e9
style
alexrockhill Jul 26, 2023
e5bd6a9
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Jul 26, 2023
e1bcbd9
fix repr raising error
drammock Jul 26, 2023
4a76c9f
make repr more honest
drammock Jul 26, 2023
7ef0626
readability
drammock Jul 26, 2023
f031c8b
fix wrong docstring
drammock Jul 26, 2023
62e25f4
make docstrings parallel
drammock Jul 26, 2023
f8e74c9
add note about assuming power
drammock Jul 26, 2023
36df9b5
through version with bug fixes, plots were checked
alexrockhill Jul 27, 2023
147d110
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 27, 2023
ba24ebb
style'
alexrockhill Jul 27, 2023
026e8e6
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Jul 27, 2023
7340ea2
Merge branch 'main' into defaults
alexrockhill Jul 27, 2023
e4ca8ca
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Jul 27, 2023
4d48f42
style
alexrockhill Jul 27, 2023
b98d2b8
fix tests
alexrockhill Jul 28, 2023
d8fefc8
fix one last tests
alexrockhill Jul 28, 2023
3c4da24
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2023
9c2cd25
style
alexrockhill Jul 28, 2023
10a2667
Merge branch 'main' into defaults
alexrockhill Aug 1, 2023
92d8a9d
spelling
alexrockhill Aug 2, 2023
5a33651
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Aug 2, 2023
6404aaa
Merge branch 'main' of https://github.com/mne-tools/mne-python into d…
alexrockhill Aug 2, 2023
da45d5c
refactor tests
alexrockhill Aug 2, 2023
1a42bf6
oops switched psd, psd2
alexrockhill Aug 3, 2023
e71672a
style
alexrockhill Aug 3, 2023
04262f2
Merge branch 'main' of https://github.com/mne-tools/mne-python into d…
alexrockhill Aug 18, 2023
b099177
resolve conflicts
alexrockhill Aug 18, 2023
265f009
cruft
alexrockhill Aug 18, 2023
2427b9f
Merge branch 'main' into defaults
alexrockhill Aug 25, 2023
54488ed
try skip h5io
alexrockhill Aug 25, 2023
1a1e57f
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Aug 25, 2023
9f45402
Merge branch 'main' into defaults
alexrockhill Aug 29, 2023
8f49d05
remove complex support
alexrockhill Aug 29, 2023
831c1ee
Merge branch 'defaults' of https://github.com/alexrockhill/mne-python…
alexrockhill Aug 29, 2023
b6707f3
Merge branch 'main' into defaults
alexrockhill Aug 29, 2023
e178231
simplifications and fixes
drammock Sep 1, 2023
88c790e
Merge remote-tracking branch 'upstream/main' into defaults
drammock Sep 1, 2023
f735e12
oops missed this
drammock Sep 1, 2023
40627ac
Update tutorials/simulation/10_array_objs.py
drammock Sep 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Enhancements
- Added option ``remove_dc`` to to :meth:`Raw.compute_psd() <mne.io.Raw.compute_psd>`, :meth:`Epochs.compute_psd() <mne.Epochs.compute_psd>`, and :meth:`Evoked.compute_psd() <mne.Evoked.compute_psd>`, to allow skipping DC removal when computing Welch or multitaper spectra (:gh:`11769` by `Nikolai Chapochnikov`_)
- Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_)
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array <numpy.ndarray>` data (:gh:`11803` by `Alex Rockhill`_)

Bugs
~~~~
Expand Down
2 changes: 2 additions & 0 deletions doc/time_frequency.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ Time-Frequency
EpochsTFR
CrossSpectralDensity
Spectrum
SpectrumArray
EpochsSpectrum
EpochsSpectrumArray

Functions that operate on mne-python objects:

Expand Down
12 changes: 12 additions & 0 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ def raw_ctf():
return raw_ctf


@pytest.fixture(scope="function")
def raw_spectrum(raw):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for naming consistency

"""Get raw with power spectral density computed from mne.io.tests.data."""
return raw.compute_psd()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we have a good reason to test both Welch and multitaper in the fixture; the differences aren't relevant to most tests (and should be covered adequately by the tests of the welch/mutitaper array methods used under the hood)



@pytest.fixture(scope="function")
def events():
"""Get events from mne.io.tests.data."""
Expand Down Expand Up @@ -349,6 +355,12 @@ def epochs_full():
return _get_epochs(None).load_data()


@pytest.fixture()
def epochs_spectrum():
"""Get epochs with power spectral density computed from mne.io.tests.data."""
return _get_epochs().load_data().compute_psd()


@pytest.fixture()
def epochs_empty():
"""Get empty epochs from mne.io.tests.data."""
Expand Down
37 changes: 7 additions & 30 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3197,40 +3197,17 @@ class EpochsArray(BaseEpochs):
measure.
%(info_not_none)s Consider using :func:`mne.create_info` to populate this
structure.
events : None | array of int, shape (n_events, 3)
The events typically returned by the read_events function.
If some events don't match the events of interest as specified
by event_id, they will be marked as 'IGNORED' in the drop log.
If None (default), all event values are set to 1 and event time-samples
are set to range(n_epochs).
tmin : float
Start time before event. If nothing provided, defaults to 0.
event_id : int | list of int | dict | None
The id of the event to consider. If dict,
the keys can later be used to access associated events. Example:
dict(auditory=1, visual=3). If int, a dict will be created with
the id as string. If a list, all events with the IDs specified
in the list are used. If None, all events will be used with
and a dict is created with string integer names corresponding
to the event id integers.
%(events_epochs)s
%(tmin_epochs)s
%(event_id)s
%(reject_epochs)s
%(flat)s
reject_tmin : scalar | None
Start of the time window used to reject epochs (with the default None,
the window will start with tmin).
reject_tmax : scalar | None
End of the time window used to reject epochs (with the default None,
the window will end with tmax).
%(epochs_reject_tmin_tmax)s
%(baseline_epochs)s
Defaults to ``None``, i.e. no baseline correction.
proj : bool | 'delayed'
Apply SSP projection vectors. See :class:`mne.Epochs` for details.
on_missing : str
See :class:`mne.Epochs` docstring for details.
metadata : instance of pandas.DataFrame | None
See :class:`mne.Epochs` docstring for details.

.. versionadded:: 0.16
%(proj_epochs)s
%(on_missing_epochs)s
%(metadata_epochs)s
%(selection)s
%(drop_log)s

Expand Down
8 changes: 7 additions & 1 deletion mne/time_frequency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
"tfr_array_multitaper",
],
"psd": ["psd_array_welch"],
"spectrum": ["EpochsSpectrum", "Spectrum", "read_spectrum"],
"spectrum": [
"EpochsSpectrum",
"EpochsSpectrumArray",
"Spectrum",
"SpectrumArray",
"read_spectrum",
],
"tfr": [
"_BaseTFR",
"AverageTFR",
Expand Down
147 changes: 143 additions & 4 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,10 @@ def __setstate__(self, state):
self._data_type = state["data_type"]
self.preload = True
# instance type
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked)
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray)
self._inst_type = inst_types[state["inst_type_str"]]
if "weights" in state and state["weights"] is not None:
self._mt_weights = state["weights"]

def __repr__(self):
"""Build string representation of the Spectrum object."""
Expand Down Expand Up @@ -486,6 +488,8 @@ def _get_instance_type_string(self):
inst_type_str = "Epochs"
elif self._inst_type in (Evoked, EvokedArray):
inst_type_str = "Evoked"
elif self._inst_type is np.ndarray:
inst_type_str = "Array"
else:
raise RuntimeError(f"Unknown instance type {self._inst_type} in Spectrum")
return inst_type_str
Expand Down Expand Up @@ -766,6 +770,8 @@ def plot_topo(
layout = find_layout(self.info)

psds, freqs = self.get_data(return_freqs=True)
if "epoch" in self._dims:
psds = np.mean(psds, axis=self._dims.index("epoch"))
if dB:
psds = 10 * np.log10(psds)
y_label = "dB"
Expand Down Expand Up @@ -977,7 +983,7 @@ def to_data_frame(
# check pandas once here, instead of in each private utils function
pd = _check_pandas_installed() # noqa
# triage for Epoch-derived or unaggregated spectra
from_epo = self._get_instance_type_string() == "Epochs"
from_epo = self._dims[0] == "epoch"
unagg_welch = "segment" in self._dims
unagg_mt = "taper" in self._dims
# arg checking
Expand Down Expand Up @@ -1089,6 +1095,7 @@ class Spectrum(BaseSpectrum):
See Also
--------
EpochsSpectrum
SpectrumArray
mne.io.Raw.compute_psd
mne.Epochs.compute_psd
mne.Evoked.compute_psd
Expand Down Expand Up @@ -1190,6 +1197,75 @@ def __getitem__(self, item):
return BaseRaw._getitem(self, item, return_times=False)


def _check_data_shape(data, freqs, info, ndim):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reworked this to check everything at once rather than needing multiple calls to the check function. Seemed simpler this way once we knew we only had to handle 2D and 3D cases.

if data.ndim != ndim:
raise ValueError(f"Data must be a {ndim}D array.")
want_n_chan = _pick_data_channels(info).size
want_n_freq = freqs.size
got_n_chan, got_n_freq = data.shape[-2:]
if got_n_chan != want_n_chan:
raise ValueError(
f"The number of channels in `data` ({got_n_chan}) must match the "
f"number of good data channels in `info` ({want_n_chan})."
)
if got_n_freq != want_n_freq:
raise ValueError(
f"The last dimension of `data` ({got_n_freq}) must have the same "
f"number of elements as `freqs` ({want_n_freq})."
)


@fill_doc
class SpectrumArray(Spectrum):
"""Data object for precomputed spectral data (in NumPy array format).

Parameters
----------
data : array, shape (n_channels, n_freqs)
The power spectral density for each channel.
%(info_not_none)s
%(freqs_tfr)s
drammock marked this conversation as resolved.
Show resolved Hide resolved
%(verbose)s

See Also
--------
mne.create_info
mne.EvokedArray
mne.io.RawArray
EpochsSpectrumArray

Notes
-----
%(notes_spectrum_array)s

.. versionadded:: 1.6
"""

@verbose
def __init__(
self,
data,
info,
freqs,
*,
verbose=None,
):
_check_data_shape(data, freqs, info, ndim=2)

self.__setstate__(
dict(
method="unknown",
data=data,
sfreq=info["sfreq"],
dims=("channel", "freq"),
freqs=freqs,
inst_type_str="Array",
data_type="Power Spectrum",
info=info,
)
)


@fill_doc
class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
"""Data object for spectral representations of epoched data.
Expand Down Expand Up @@ -1225,10 +1301,9 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):

See Also
--------
EpochsSpectrumArray
Spectrum
mne.io.Raw.compute_psd
mne.Epochs.compute_psd
mne.Evoked.compute_psd

References
----------
Expand Down Expand Up @@ -1385,6 +1460,70 @@ def average(self, method="mean"):
return Spectrum(state, **defaults)


@fill_doc
class EpochsSpectrumArray(EpochsSpectrum):
"""Data object for precomputed epoched spectral data (in NumPy array format).

Parameters
----------
data : array, shape (n_epochs, n_channels, n_freqs)
The power spectral density for each channel in each epoch.
%(info_not_none)s
%(freqs_tfr)s
%(events_epochs)s
%(event_id)s
%(verbose)s

See Also
--------
mne.create_info
mne.EpochsArray
SpectrumArray

Notes
-----
%(notes_spectrum_array)s

.. versionadded:: 1.6
"""

@verbose
def __init__(
self,
data,
info,
freqs,
events=None,
event_id=None,
*,
verbose=None,
):
_check_data_shape(data, freqs, info, ndim=3)
if events is not None and data.shape[0] != events.shape[0]:
raise ValueError(
f"The first dimension of `data` ({data.shape[0]}) must match the "
f"first dimension of `events` ({events.shape[0]})."
)

self.__setstate__(
dict(
method="unknown",
data=data,
sfreq=info["sfreq"],
dims=("epoch", "channel", "freq"),
freqs=freqs,
inst_type_str="Array",
data_type="Power Spectrum",
info=info,
events=events,
event_id=event_id,
metadata=None,
selection=np.arange(data.shape[0]),
drop_log=tuple(tuple() for _ in range(data.shape[0])),
)
)


def read_spectrum(fname):
"""Load a :class:`mne.time_frequency.Spectrum` object from disk.

Expand Down
Loading
Loading