Skip to content

Commit

Permalink
add plot_spectra_3d
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Dec 3, 2024
1 parent 4976524 commit da60bcf
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
65 changes: 65 additions & 0 deletions neurodsp/plts/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,68 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None,
if spectrum is not None:
plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1])
ax.plot(spectrum_freqs[plt_inds], np.log10(spectrum[plt_inds]), color='w', alpha=0.8)


@savefig
@style_plot
def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True,
colors=None, orientation=(20, -50), **kwargs):
"""Plot a series of power spectra in a 3D plot.
Parameters
----------
freqs : 1d or 2d array or list of 1d array
Frequency vector.
powers : 2d array or list of 1d array
Power values.
log_freqs : bool, optional, default: False
Whether to plot the frequency values in log10 space.
log_powers : bool, optional, default: True
Whether to plot the power values in log10 space.
colors : str or list of str
Colors to use to plot lines.
orientation : tuple of int
Orientation to set the 3D plot.
**kwargs
Keyword arguments for customizing the plot.
Examples
--------
Plot power spectra in 3D:
>>> from neurodsp.sim import sim_combined
>>> from neurodsp.spectral import compute_spectrum
>>> sig1 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> sig2 = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'exponent' : -1.5},
... 'sim_bursty_oscillation' : {'freq': 10}})
>>> freqs1, powers1 = compute_spectrum(sig1, fs=500)
>>> freqs2, powers2 = compute_spectrum(sig2, fs=500)
>>> plot_spectra_3D([freqs1, freqs2], [powers1, powers2])
"""

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

n_spectra = len(powers)

for ind, (freq, power, _, color) in \
enumerate(zip(*prepare_multi_plot(freqs, powers, None, colors))):
ax.plot(xs=np.log10(freq) if log_freqs else freq,
ys=[ind] * len(freq),
zs=np.log10(power) if log_powers else power,
color=color,
**kwargs)

ax.set(
xlabel='Frequency (Hz)',
ylabel='Channels',
zlabel='Power',
ylim=[0, n_spectra - 1],
)

yticks = list(range(n_spectra))
ax.set_yticks(yticks, yticks)
ax.view_init(*orientation)
15 changes: 15 additions & 0 deletions neurodsp/tests/plts/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,18 @@ def test_plot_spectral_hist(tsig_comb):
spectrum=spectrum, spectrum_freqs=spectrum_freqs,
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectral_hist.png')

@plot_test
def test_plot_spectra_3D(tsig_comb, tsig_burst):

freqs1, powers1 = compute_spectrum(tsig_comb, FS)
freqs2, powers2 = compute_spectrum(tsig_burst, FS)

plot_spectra_3D([freqs1, freqs2], [powers1, powers2],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectral3D_1.png')

plot_spectra_3D(freqs1, [powers1, powers2, powers1, powers2],
colors=['r', 'y', 'b', 'g'],
save_fig=True, file_path=TEST_PLOTS_PATH,
file_name='test_plot_spectral3D_2.png')

0 comments on commit da60bcf

Please sign in to comment.