Skip to content

Commit

Permalink
Update from review
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jul 30, 2024
1 parent 85ef415 commit 01f26c3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 29 deletions.
51 changes: 24 additions & 27 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,23 +1214,22 @@ def __getitem__(self, item):
return BaseRaw._getitem(self, item, return_times=False)


def _check_data_shape(data, info, freqs, dimnames, weights):
def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched):
if data.ndim != len(dimnames):
raise ValueError(
f"Expected data to have {len(dimnames)} dimensions, got {data.ndim}."
)

is_epoched = 1 if "epoch" in dimnames else 0
allowed_dims = ["epoch", "channel", "freq", "segment", "taper"]
allowed_dims = allowed_dims[0 if is_epoched else 1 :]
if set(allowed_dims).intersection(dimnames) != set(dimnames):
raise ValueError(
f"All entries of `dimnames` must be in {allowed_dims}, got {dimnames}."
)
if not is_epoched:
allowed_dims.remove("epoch")
# TODO maybe we should be nice and allow plural versions of each dimname?
for dim in dimnames:
_check_option("dimnames", dim, allowed_dims)
if "channel" not in dimnames or "freq" not in dimnames:
raise ValueError("Both 'channel' and 'freq' must be present in `dimnames`.")

if list(dimnames).index("channel") != is_epoched:
if list(dimnames).index("channel") != int(is_epoched):
raise ValueError(
f"'channel' must be the {'second' if is_epoched else 'first'} dimension of "
"the data."
Expand All @@ -1255,8 +1254,8 @@ def _check_data_shape(data, info, freqs, dimnames, weights):
expected = data.shape[list(dimnames).index("taper")]
if actual != expected:
raise ValueError(
f"Expected size of `weights` to be {expected} to match `data`, got "
f"{actual}."
f"Expected size of `weights` to be {expected} to match 'n_tapers' in "
f"`data`, got {actual}."
)
elif "segment" in dimnames and dimnames[-1] != "segment":
raise ValueError("'segment' must be the last dimension of the data.")
Expand All @@ -1282,11 +1281,13 @@ class SpectrumArray(Spectrum):
%(info_not_none)s
%(freqs_tfr_array)s
dimnames : tuple of str
The name of the dimensions in the data. Must contain ``'channel'`` and
``'freq'``. Can also contain one of ``'taper'`` or ``'segment'``.
The name of the dimensions in the data, in the order they occur. Must contain
``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include
either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g.,
multitaper algorithms) dimension. If including ``'taper'``, you should also pass
a ``weights`` parameter.
weights : ndarray | None
The multitaper weights used for averaging across tapers. Only required if data
from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``.
Weights for the ``'taper'`` dimension, if present (see ``dimnames``).
%(verbose)s
See Also
Expand Down Expand Up @@ -1317,13 +1318,7 @@ def __init__(
# (channel, [taper], freq, [segment])
_check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension

if "epoch" in dimnames:
raise ValueError(
"'`data` must not be epoched. Use mne.time_frequency."
"EpochsSpectrumArray for storing epoched spectral data."
)

_check_data_shape(data, info, freqs, dimnames, weights)
_check_data_shape(data, info, freqs, dimnames, weights, is_epoched=False)

self.__setstate__(
dict(
Expand Down Expand Up @@ -1560,11 +1555,13 @@ class EpochsSpectrumArray(EpochsSpectrum):
%(events_epochs)s
%(event_id)s
dimnames : tuple of str
The name of the dimensions in the data. Must contain ``'epoch'``, ``'channel'``,
and ``'freq'``. Can also contain one of ``'taper'`` or ``'segment'``.
The name of the dimensions in the data, in the order they occur. Must contain
``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include
either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g.,
multitaper algorithms) dimension. If including ``'taper'``, you should also pass
a ``weights`` parameter.
weights : ndarray | None
The multitaper weights used for averaging across tapers. Only required if data
from :func:`~mne.time_frequency.psd_array_multitaper` with ``output='complex'``.
Weights for the ``'taper'`` dimension, if present (see ``dimnames``).
%(verbose)s
See Also
Expand Down Expand Up @@ -1596,15 +1593,15 @@ def __init__(
# (epoch, channel, [taper], freq, [segment])
_check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension

if "epoch" not in dimnames or list(dimnames).index("epoch") != 0:
if list(dimnames).index("epoch") != 0:
raise ValueError("'epoch' must be the first dimension of `data`.")
if events is not None and data.shape[0] != events.shape[0]:
raise ValueError(
f"The first dimension of `data` ({data.shape[0]}) must match the first "
f"dimension of `events` ({events.shape[0]})."
)

_check_data_shape(data, info, freqs, dimnames, weights)
_check_data_shape(data, info, freqs, dimnames, weights, is_epoched=True)

self.__setstate__(
dict(
Expand Down
4 changes: 2 additions & 2 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,10 @@ def test_spectrum_array_errors():
with pytest.raises(ValueError, match=r"Expected data to have.*dimensions, got.*"):
EpochsSpectrumArray(data, info, freqs, dimnames=dimnames[:-1])
# test unrecognised dimnames (for SpectrumArray; epoch not allowed)
with pytest.raises(ValueError, match="`data` must not be epoched"):
with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"):
SpectrumArray(data[0, :, :], info, freqs, dimnames=("epoch", "channel"))
# test unrecognised dimnames (for EpochsSpectrumArray)
with pytest.raises(ValueError, match=r"entries of `dimnames` must be in.*, got,*"):
with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"):
EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "notfreq"))
# test missing dimnames
with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"):
Expand Down

0 comments on commit 01f26c3

Please sign in to comment.