From 7badfbf2f4543a968785e6de1d10c17e3dec2ef9 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 30 Jul 2024 19:54:55 +0200 Subject: [PATCH] Update (Epochs)SpectrumArray docstrings --- mne/time_frequency/spectrum.py | 62 +++++++++++-------- mne/time_frequency/tests/test_spectrum.py | 73 +++++++++++------------ mne/utils/docs.py | 12 ++-- 3 files changed, 81 insertions(+), 66 deletions(-) diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index e35942b2a88..cc643490915 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1098,6 +1098,8 @@ class Spectrum(BaseSpectrum): The weights for each taper. Only present if spectra computed with ``method='multitaper'`` and ``output='complex'``. + .. versionadded:: 1.8 + See Also -------- EpochsSpectrum @@ -1214,28 +1216,28 @@ def __getitem__(self, item): return BaseRaw._getitem(self, item, return_times=False) -def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched): - if data.ndim != len(dimnames): +def _check_data_shape(data, info, freqs, dim_names, weights, is_epoched): + if data.ndim != len(dim_names): raise ValueError( - f"Expected data to have {len(dimnames)} dimensions, got {data.ndim}." + f"Expected data to have {len(dim_names)} dimensions, got {data.ndim}." ) allowed_dims = ["epoch", "channel", "freq", "segment", "taper"] if not is_epoched: allowed_dims.remove("epoch") # TODO maybe we should be nice and allow plural versions of each dimname? - for dim in dimnames: - _check_option("dimnames", dim, allowed_dims) - if "channel" not in dimnames or "freq" not in dimnames: - raise ValueError("Both 'channel' and 'freq' must be present in `dimnames`.") + for dim in dim_names: + _check_option("dim_names", dim, allowed_dims) + if "channel" not in dim_names or "freq" not in dim_names: + raise ValueError("Both 'channel' and 'freq' must be present in `dim_names`.") - if list(dimnames).index("channel") != int(is_epoched): + if list(dim_names).index("channel") != int(is_epoched): raise ValueError( f"'channel' must be the {'second' if is_epoched else 'first'} dimension of " "the data." ) want_n_chan = _pick_data_channels(info).size - got_n_chan = data.shape[list(dimnames).index("channel")] + got_n_chan = data.shape[list(dim_names).index("channel")] if got_n_chan != want_n_chan: raise ValueError( f"The number of channels in `data` ({got_n_chan}) must match the number of " @@ -1244,25 +1246,25 @@ def _check_data_shape(data, info, freqs, dimnames, weights, is_epoched): # given we limit max array size and ensure channel & freq dims present, only one of # taper or segment can be present - if "taper" in dimnames: - if dimnames[-2] != "taper": # _psd_from_mt assumes this (called when plotting) + if "taper" in dim_names: + if dim_names[-2] != "taper": # _psd_from_mt assumes this (called when plotting) raise ValueError( "'taper' must be the second to last dimension of the data." ) # expect weights for each taper actual = None if weights is None else weights.size - expected = data.shape[list(dimnames).index("taper")] + expected = data.shape[list(dim_names).index("taper")] if actual != expected: raise ValueError( f"Expected size of `weights` to be {expected} to match 'n_tapers' in " f"`data`, got {actual}." ) - elif "segment" in dimnames and dimnames[-1] != "segment": + elif "segment" in dim_names and dim_names[-1] != "segment": raise ValueError("'segment' must be the last dimension of the data.") # freq being in wrong position ruled out by above checks want_n_freq = freqs.size - got_n_freq = data.shape[list(dimnames).index("freq")] + got_n_freq = data.shape[list(dim_names).index("freq")] if got_n_freq != want_n_freq: raise ValueError( f"The number of frequencies in `data` ({got_n_freq}) must match the number " @@ -1280,14 +1282,18 @@ class SpectrumArray(Spectrum): The spectra for each channel. %(info_not_none)s %(freqs_tfr_array)s - dimnames : tuple of str + dim_names : tuple of str The name of the dimensions in the data, in the order they occur. Must contain ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., multitaper algorithms) dimension. If including ``'taper'``, you should also pass a ``weights`` parameter. + + .. versionadded:: 1.8 weights : ndarray | None - Weights for the ``'taper'`` dimension, if present (see ``dimnames``). + Weights for the ``'taper'`` dimension, if present (see ``dim_names``). + + .. versionadded:: 1.8 %(verbose)s See Also @@ -1310,7 +1316,7 @@ def __init__( data, info, freqs, - dimnames=("channel", "freq"), + dim_names=("channel", "freq"), weights=None, *, verbose=None, @@ -1318,14 +1324,14 @@ def __init__( # (channel, [taper], freq, [segment]) _check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension - _check_data_shape(data, info, freqs, dimnames, weights, is_epoched=False) + _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=False) self.__setstate__( dict( method="unknown", data=data, sfreq=info["sfreq"], - dims=dimnames, + dims=dim_names, freqs=freqs, inst_type_str="Array", data_type=( @@ -1376,6 +1382,8 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): The weights for each taper. Only present if spectra computed with ``method='multitaper'`` and ``output='complex'``. + .. versionadded:: 1.8 + See Also -------- EpochsSpectrumArray @@ -1554,14 +1562,18 @@ class EpochsSpectrumArray(EpochsSpectrum): %(freqs_tfr_array)s %(events_epochs)s %(event_id)s - dimnames : tuple of str + dim_names : tuple of str The name of the dimensions in the data, in the order they occur. Must contain ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., multitaper algorithms) dimension. If including ``'taper'``, you should also pass a ``weights`` parameter. + + .. versionadded:: 1.8 weights : ndarray | None - Weights for the ``'taper'`` dimension, if present (see ``dimnames``). + Weights for the ``'taper'`` dimension, if present (see ``dim_names``). + + .. versionadded:: 1.8 %(verbose)s See Also @@ -1585,7 +1597,7 @@ def __init__( freqs, events=None, event_id=None, - dimnames=("epoch", "channel", "freq"), + dim_names=("epoch", "channel", "freq"), weights=None, *, verbose=None, @@ -1593,7 +1605,7 @@ def __init__( # (epoch, channel, [taper], freq, [segment]) _check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension - if list(dimnames).index("epoch") != 0: + if list(dim_names).index("epoch") != 0: raise ValueError("'epoch' must be the first dimension of `data`.") if events is not None and data.shape[0] != events.shape[0]: raise ValueError( @@ -1601,14 +1613,14 @@ def __init__( f"dimension of `events` ({events.shape[0]})." ) - _check_data_shape(data, info, freqs, dimnames, weights, is_epoched=True) + _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=True) self.__setstate__( dict( method="unknown", data=data, sfreq=info["sfreq"], - dims=dimnames, + dims=dim_names, freqs=freqs, inst_type_str="Array", data_type=( diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index d9e1917ce61..980df42d791 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -439,15 +439,6 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path): assert_array_equal(spect1.freqs, spect2.freqs) -def _get_dimnames(kind, method, output, average): - dimnames = ("epoch", "channel") if kind == "epochs" else ("channel",) - if method == "welch": - dimnames += ("freq",) if average else ("freq", "segment") - else: # i.e. multitaper - dimnames += ("freq",) if output == "power" else ("taper", "freq") - return dimnames - - def test_spectrum_array_errors(): """Test (Epochs)SpectrumArray constructor errors.""" n_epochs = 10 @@ -457,23 +448,23 @@ def test_spectrum_array_errors(): sfreq = 100 rng = np.random.default_rng(44) data = rng.random((n_epochs, n_chans, n_freqs)) - dimnames = ("epoch", "channel", "freq") + dim_names = ("epoch", "channel", "freq") info = create_info(n_chans, sfreq, "eeg") # test incorrect ndims (for SpectrumArray; allows 2-3D data) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): - SpectrumArray(data[0, 0, :], info, freqs, dimnames=dimnames) + SpectrumArray(data[0, 0, :], info, freqs, dim_names=dim_names) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): - SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames) + SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names) # test incorrect ndims (for EpochsSpectrumArray; allows 3-4D data) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): - EpochsSpectrumArray(data[0, :, :], info, freqs, dimnames=dimnames) + EpochsSpectrumArray(data[0, :, :], info, freqs, dim_names=dim_names) with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): EpochsSpectrumArray( - np.expand_dims(data, axis=(3, 4)), info, freqs, dimnames=dimnames + np.expand_dims(data, axis=(3, 4)), info, freqs, dim_names=dim_names ) # test incorrect epochs location with pytest.raises(ValueError, match="'epoch' must be the first dimension"): - EpochsSpectrumArray(data, info, freqs, dimnames=("channel", "epoch", "freq")) + EpochsSpectrumArray(data, info, freqs, dim_names=("channel", "epoch", "freq")) # test mismatching events shape events = np.vstack( ( @@ -483,36 +474,40 @@ def test_spectrum_array_errors(): ) ).T with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): - EpochsSpectrumArray(data, info, freqs, events, dimnames=dimnames) + EpochsSpectrumArray(data, info, freqs, events, dim_names=dim_names) # test data-dimname mismatch with pytest.raises(ValueError, match=r"Expected data to have.*dimensions, got.*"): - EpochsSpectrumArray(data, info, freqs, dimnames=dimnames[:-1]) - # test unrecognised dimnames (for SpectrumArray; epoch not allowed) - with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"): - SpectrumArray(data[0, :, :], info, freqs, dimnames=("epoch", "channel")) - # test unrecognised dimnames (for EpochsSpectrumArray) - with pytest.raises(ValueError, match="Invalid value for the 'dimnames' parameter"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "notfreq")) - # test missing dimnames + EpochsSpectrumArray(data, info, freqs, dim_names=dim_names[:-1]) + # test unrecognised dim_names (for SpectrumArray; epoch not allowed) + with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"): + SpectrumArray(data[0, :, :], info, freqs, dim_names=("epoch", "channel")) + # test unrecognised dim_names (for EpochsSpectrumArray) + with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"): + EpochsSpectrumArray( + data, info, freqs, dim_names=("epoch", "channel", "notfreq") + ) + # test missing dim_names with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "channel", "channel")) + EpochsSpectrumArray( + data, info, freqs, dim_names=("epoch", "channel", "channel") + ) with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "freq")) + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "freq")) with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "epoch", "epoch")) + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "epoch", "epoch")) # test incorrect channel location (for SpectrumArray; must be 1st dim) with pytest.raises(ValueError, match="'channel' must be the first dimension"): - SpectrumArray(data[0, :, :], info, freqs, dimnames=("freq", "channel")) + SpectrumArray(data[0, :, :], info, freqs, dim_names=("freq", "channel")) # test incorrect channel location (for EpochsSpectrumArray; must be 2nd dim) with pytest.raises(ValueError, match="'channel' must be the second dimension"): - EpochsSpectrumArray(data, info, freqs, dimnames=("epoch", "freq", "channel")) + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "channel")) # test mismatching number of channels with pytest.raises(ValueError, match=r"number of channels.*good data channels"): - EpochsSpectrumArray(data[:, :-1, :], info, freqs, dimnames=dimnames) + EpochsSpectrumArray(data[:, :-1, :], info, freqs, dim_names=dim_names) # test incorrect taper position with pytest.raises(ValueError, match="'taper' must be the second to last dim"): EpochsSpectrumArray( - np.expand_dims(data, axis=3), info, freqs, dimnames=dimnames + ("taper",) + np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names + ("taper",) ) # test incorrect weight size with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): @@ -520,7 +515,7 @@ def test_spectrum_array_errors(): np.expand_dims(data, axis=2), info, freqs, - dimnames=("epoch", "channel", "taper", "freq"), + dim_names=("epoch", "channel", "taper", "freq"), weights=None, ) with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): @@ -528,7 +523,7 @@ def test_spectrum_array_errors(): np.expand_dims(data, axis=2), info, freqs, - dimnames=("epoch", "channel", "taper", "freq"), + dim_names=("epoch", "channel", "taper", "freq"), weights=np.ones((1, 2, 1)), ) # test incorrect segment position @@ -537,11 +532,11 @@ def test_spectrum_array_errors(): np.expand_dims(data, axis=2), info, freqs, - dimnames=("epoch", "channel", "segment", "freq"), + dim_names=("epoch", "channel", "segment", "freq"), ) # test mismatching number of frequencies with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"): - EpochsSpectrumArray(data[:, :, :-1], info, freqs, dimnames=dimnames) + EpochsSpectrumArray(data[:, :, :-1], info, freqs, dim_names=dim_names) @pytest.mark.parametrize( @@ -554,7 +549,11 @@ def test_spectrum_array_errors(): ) def test_spectrum_array(kind, method, output, average, tmp_path, request): """Test EpochsSpectrumArray and SpectrumArray constructors.""" - dimnames = _get_dimnames(kind, method, output, average) + dim_names = ("epoch", "channel") if kind == "epochs" else ("channel",) + if method == "welch": + dim_names += ("freq",) if average else ("freq", "segment") + else: # i.e. multitaper + dim_names += ("freq",) if output == "power" else ("taper", "freq") if method == "welch" and output == "power" and average: spectrum = request.getfixturevalue(f"{kind}_spectrum") else: @@ -569,7 +568,7 @@ def test_spectrum_array(kind, method, output, average, tmp_path, request): data=data, info=spectrum.info, freqs=freqs, - dimnames=dimnames, + dim_names=dim_names, weights=spectrum.weights, ) _check_spectrum_equivalent(spectrum, spect_arr, tmp_path) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index ff9e11ee776..57a0999fd1e 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2922,11 +2922,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["notes_plot_psd_meth"] = _notes_plot_psd.format("method") docdict["notes_spectrum_array"] = """ -It is assumed that the data passed in represent spectral *power* (not amplitude, -phase, model coefficients, etc) and downstream methods (such as +If the data passed in is real-valued, it is assumed to represent spectral *power* (not +amplitude, phase, etc), and downstream methods (such as :meth:`~mne.time_frequency.SpectrumArray.plot`) assume power data. If you pass in -something other than power, at the very least axis labels will be inaccurate (and -other things may also not work or be incorrect). +real-valued data that is not power, axis labels will be incorrect. + +If the data passed in is complex-valued, it is assumed to represent Fourier +coefficients. Downstream plotting methods will treat the data as such, attempting to +convert this to power before visualisation. If you pass in complex-valued data that is +not Fourier coefficients, axis labels will be incorrect. """ docdict["notes_timefreqs_tfr_plot_joint"] = """