-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH, MRG] Add EpochsSpectrumArray and SpectrumArray classes #11803
Changes from all commits
395a758
710cef4
e34bc38
3fcb557
597c48c
6a54f68
125903e
e0fb07f
8cee941
738c8d3
0cd6ec8
5d18ff3
27be8d8
145b6bb
71016d6
1ccf8a4
ece2b39
ff203ff
725446d
ecbd0a8
f8e98ce
cb80688
fd830e9
e5bd6a9
e1bcbd9
4a76c9f
7ef0626
f031c8b
62e25f4
f8e74c9
36df9b5
147d110
ba24ebb
026e8e6
7340ea2
e4ca8ca
4d48f42
b98d2b8
d8fefc8
3c4da24
9c2cd25
10a2667
92d8a9d
5a33651
6404aaa
da45d5c
1a42bf6
e71672a
04262f2
b099177
265f009
2427b9f
54488ed
1a1e57f
9f45402
8f49d05
831c1ee
b6707f3
e178231
88c790e
f735e12
40627ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -297,6 +297,12 @@ def raw_ctf(): | |
return raw_ctf | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def raw_spectrum(raw): | ||
"""Get raw with power spectral density computed from mne.io.tests.data.""" | ||
return raw.compute_psd() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we have a good reason to test both Welch and multitaper in the fixture; the differences aren't relevant to most tests (and should be covered adequately by the tests of the welch/mutitaper array methods used under the hood) |
||
|
||
|
||
@pytest.fixture(scope="function") | ||
def events(): | ||
"""Get events from mne.io.tests.data.""" | ||
|
@@ -349,6 +355,12 @@ def epochs_full(): | |
return _get_epochs(None).load_data() | ||
|
||
|
||
@pytest.fixture() | ||
def epochs_spectrum(): | ||
"""Get epochs with power spectral density computed from mne.io.tests.data.""" | ||
return _get_epochs().load_data().compute_psd() | ||
|
||
|
||
@pytest.fixture() | ||
def epochs_empty(): | ||
"""Get empty epochs from mne.io.tests.data.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -397,8 +397,10 @@ def __setstate__(self, state): | |
self._data_type = state["data_type"] | ||
self.preload = True | ||
# instance type | ||
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked) | ||
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) | ||
self._inst_type = inst_types[state["inst_type_str"]] | ||
if "weights" in state and state["weights"] is not None: | ||
self._mt_weights = state["weights"] | ||
|
||
def __repr__(self): | ||
"""Build string representation of the Spectrum object.""" | ||
|
@@ -486,6 +488,8 @@ def _get_instance_type_string(self): | |
inst_type_str = "Epochs" | ||
elif self._inst_type in (Evoked, EvokedArray): | ||
inst_type_str = "Evoked" | ||
elif self._inst_type is np.ndarray: | ||
inst_type_str = "Array" | ||
else: | ||
raise RuntimeError(f"Unknown instance type {self._inst_type} in Spectrum") | ||
return inst_type_str | ||
|
@@ -766,6 +770,8 @@ def plot_topo( | |
layout = find_layout(self.info) | ||
|
||
psds, freqs = self.get_data(return_freqs=True) | ||
if "epoch" in self._dims: | ||
psds = np.mean(psds, axis=self._dims.index("epoch")) | ||
if dB: | ||
psds = 10 * np.log10(psds) | ||
y_label = "dB" | ||
|
@@ -977,7 +983,7 @@ def to_data_frame( | |
# check pandas once here, instead of in each private utils function | ||
pd = _check_pandas_installed() # noqa | ||
# triage for Epoch-derived or unaggregated spectra | ||
from_epo = self._get_instance_type_string() == "Epochs" | ||
from_epo = self._dims[0] == "epoch" | ||
unagg_welch = "segment" in self._dims | ||
unagg_mt = "taper" in self._dims | ||
# arg checking | ||
|
@@ -1089,6 +1095,7 @@ class Spectrum(BaseSpectrum): | |
See Also | ||
-------- | ||
EpochsSpectrum | ||
SpectrumArray | ||
mne.io.Raw.compute_psd | ||
mne.Epochs.compute_psd | ||
mne.Evoked.compute_psd | ||
|
@@ -1190,6 +1197,75 @@ def __getitem__(self, item): | |
return BaseRaw._getitem(self, item, return_times=False) | ||
|
||
|
||
def _check_data_shape(data, freqs, info, ndim): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I reworked this to check everything at once rather than needing multiple calls to the check function. Seemed simpler this way once we knew we only had to handle 2D and 3D cases. |
||
if data.ndim != ndim: | ||
raise ValueError(f"Data must be a {ndim}D array.") | ||
want_n_chan = _pick_data_channels(info).size | ||
want_n_freq = freqs.size | ||
got_n_chan, got_n_freq = data.shape[-2:] | ||
if got_n_chan != want_n_chan: | ||
raise ValueError( | ||
f"The number of channels in `data` ({got_n_chan}) must match the " | ||
f"number of good data channels in `info` ({want_n_chan})." | ||
) | ||
if got_n_freq != want_n_freq: | ||
raise ValueError( | ||
f"The last dimension of `data` ({got_n_freq}) must have the same " | ||
f"number of elements as `freqs` ({want_n_freq})." | ||
) | ||
|
||
|
||
@fill_doc | ||
class SpectrumArray(Spectrum): | ||
"""Data object for precomputed spectral data (in NumPy array format). | ||
|
||
Parameters | ||
---------- | ||
data : array, shape (n_channels, n_freqs) | ||
The power spectral density for each channel. | ||
%(info_not_none)s | ||
%(freqs_tfr)s | ||
drammock marked this conversation as resolved.
Show resolved
Hide resolved
|
||
%(verbose)s | ||
|
||
See Also | ||
-------- | ||
mne.create_info | ||
mne.EvokedArray | ||
mne.io.RawArray | ||
EpochsSpectrumArray | ||
|
||
Notes | ||
----- | ||
%(notes_spectrum_array)s | ||
|
||
.. versionadded:: 1.6 | ||
""" | ||
|
||
@verbose | ||
def __init__( | ||
self, | ||
data, | ||
info, | ||
freqs, | ||
*, | ||
verbose=None, | ||
): | ||
_check_data_shape(data, freqs, info, ndim=2) | ||
|
||
self.__setstate__( | ||
dict( | ||
method="unknown", | ||
data=data, | ||
sfreq=info["sfreq"], | ||
dims=("channel", "freq"), | ||
freqs=freqs, | ||
inst_type_str="Array", | ||
data_type="Power Spectrum", | ||
info=info, | ||
) | ||
) | ||
|
||
|
||
@fill_doc | ||
class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): | ||
"""Data object for spectral representations of epoched data. | ||
|
@@ -1225,10 +1301,9 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): | |
|
||
See Also | ||
-------- | ||
EpochsSpectrumArray | ||
Spectrum | ||
mne.io.Raw.compute_psd | ||
mne.Epochs.compute_psd | ||
mne.Evoked.compute_psd | ||
|
||
References | ||
---------- | ||
|
@@ -1385,6 +1460,70 @@ def average(self, method="mean"): | |
return Spectrum(state, **defaults) | ||
|
||
|
||
@fill_doc | ||
class EpochsSpectrumArray(EpochsSpectrum): | ||
"""Data object for precomputed epoched spectral data (in NumPy array format). | ||
|
||
Parameters | ||
---------- | ||
data : array, shape (n_epochs, n_channels, n_freqs) | ||
The power spectral density for each channel in each epoch. | ||
%(info_not_none)s | ||
%(freqs_tfr)s | ||
%(events_epochs)s | ||
%(event_id)s | ||
%(verbose)s | ||
|
||
See Also | ||
-------- | ||
mne.create_info | ||
mne.EpochsArray | ||
SpectrumArray | ||
|
||
Notes | ||
----- | ||
%(notes_spectrum_array)s | ||
|
||
.. versionadded:: 1.6 | ||
""" | ||
|
||
@verbose | ||
def __init__( | ||
self, | ||
data, | ||
info, | ||
freqs, | ||
events=None, | ||
event_id=None, | ||
*, | ||
verbose=None, | ||
): | ||
_check_data_shape(data, freqs, info, ndim=3) | ||
if events is not None and data.shape[0] != events.shape[0]: | ||
raise ValueError( | ||
f"The first dimension of `data` ({data.shape[0]}) must match the " | ||
f"first dimension of `events` ({events.shape[0]})." | ||
) | ||
|
||
self.__setstate__( | ||
dict( | ||
method="unknown", | ||
data=data, | ||
sfreq=info["sfreq"], | ||
dims=("epoch", "channel", "freq"), | ||
freqs=freqs, | ||
inst_type_str="Array", | ||
data_type="Power Spectrum", | ||
info=info, | ||
events=events, | ||
event_id=event_id, | ||
metadata=None, | ||
selection=np.arange(data.shape[0]), | ||
drop_log=tuple(tuple() for _ in range(data.shape[0])), | ||
) | ||
) | ||
|
||
|
||
def read_spectrum(fname): | ||
"""Load a :class:`mne.time_frequency.Spectrum` object from disk. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for naming consistency