-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdpss.py
90 lines (80 loc) · 3.4 KB
/
dpss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from scipy.linalg import eigh_tridiagonal
from scipy.linalg import toeplitz
import matplotlib.pyplot as plt
import pyfftw, multiprocessing
class DPSS(object):
    '''
    Generate discrete prolate spheroidal sequences (DPSS, Slepian tapers),
    and plot them and their Fourier transforms.

    Attributes set by the constructor:
        N:    sequence length
        W:    half resolution bandwidth, computed as NW / N
        K:    number of tapers kept
        vecs: array of shape (K, N); row k is the taper with the k-th largest
              eigenvalue, each row having unit energy
        vals: the K eigenvalues matching the rows of vecs, or None until they
              are requested or needed for a plot legend
    '''
    # Thread count handed to pyFFTW: one per CPU reported by the OS.
    n_thread = multiprocessing.cpu_count()

    def __init__(self, N=1000, NW=3, K=4, eigenvalue=False):
        '''
        N: sequence length
        NW: N * W, where W is the half resolution bandwidth; integers are preferable
        K: number of tapers, in descending order of their eigenvalues;
           must satisfy 1 <= K <= N
        eigenvalue: when true, also compute the eigenvalues right away

        Raises ValueError when K is outside 1..N.
        '''
        if not 1 <= K <= N:
            raise ValueError(
                "K must satisfy 1 <= K <= N, got K={!r} with N={!r}".format(K, N))
        self.N = N
        self.W = NW / N
        self.K = K
        self.gen_sequences(eigenvalue)

    def gen_sequences(self, eigenvalue=False):
        '''
        Generate the discrete prolate spheroidal sequences in the time domain.

        The tapers are obtained as the eigenvectors belonging to the K largest
        eigenvalues of a symmetric tridiagonal matrix, which avoids building a
        dense N x N matrix. The result is stored in self.vecs; self.vals holds
        the eigenvalues when `eigenvalue` is true and None otherwise.
        '''
        n = np.arange(self.N)
        diag_main = ((self.N - 1) / 2 - n) ** 2 * np.cos(2 * np.pi * self.W)
        # off-diagonal entries are i * (N - i) / 2 for i = 1 .. N-1
        diag_off = np.arange(1, self.N) * np.arange(self.N - 1, 0, -1) / 2
        # select='i' returns eigenpairs in ascending order, so the index range
        # (N-K, N-1) picks the K largest; only the eigenvectors are kept.
        vecs = eigh_tridiagonal(
            diag_main, diag_off,
            select='i', select_range=(self.N - self.K, self.N - 1))[1]
        # Flip each column so that its first sample is positive, then lay the
        # tapers out as rows with the largest eigenvalue first.
        polarity = np.where(vecs[0, :] > 0, 1, -1)
        self.vecs = (vecs * polarity).T[::-1]
        # Reset any eigenvalues left over from a previous call.
        self.vals = self._compute_eigenvalues() if eigenvalue else None

    def _compute_eigenvalues(self):
        '''Return the Rayleigh quotient v A v^T of every taper v, with A the sinc Toeplitz kernel.'''
        lags = np.arange(1, self.N)
        # First column of the kernel: 2W on the diagonal, sin(2 pi W n) / (pi n) off it.
        first_col = np.insert(
            np.sin(2 * np.pi * self.W * lags) / (np.pi * lags), 0, 2 * self.W)
        A = toeplitz(first_col)
        # Row-wise dot products give the diagonal of vecs @ A @ vecs.T directly.
        return np.einsum("ij,ij->i", self.vecs @ A, self.vecs)

    def _ensure_eigenvalues(self):
        '''Return self.vals, computing and caching it first if it is not available yet.'''
        vals = getattr(self, "vals", None)
        if vals is None:
            vals = self.vals = self._compute_eigenvalues()
        return vals

    def _title(self):
        '''Build the plot title; {:g} absorbs the round-off in N * W and keeps a fractional NW.'''
        return r"$NW$ = {:g}".format(self.N * self.W)

    def plot_sequences(self):
        '''
        Plot every taper against the sample index, labelled with its eigenvalue.

        Closes all existing matplotlib figures before drawing, then shows the new one.
        '''
        vals = self._ensure_eigenvalues()
        plt.close("all")
        _, ax = plt.subplots()
        index = np.arange(self.N)
        for val, vec in zip(vals, self.vecs):
            ax.plot(index, vec, label="{:.7f}".format(val))
        ax.axhline(color='k')
        ax.legend()
        ax.set_xlim([index.min(), index.max()])
        ax.set_xlabel("index")
        ax.set_ylabel("amplitude")
        ax.set_title(self._title())
        plt.tight_layout(pad=.5)  # pad must be passed by keyword
        plt.show()

    def gen_spectra(self, n_point=None):
        '''
        Calculate the corresponding Fourier transforms in the frequency domain.

        n_point: FFT length; defaults to the smallest power of two >= N

        Returns the FFT of every row of self.vecs, as computed by pyFFTW.
        '''
        if n_point is None:
            n_point = int(np.power(2, np.ceil(np.log2(self.N))))
        # The aligned dummy array is only used to plan the transform.
        dummy = pyfftw.empty_aligned((self.K, self.N))
        fft = pyfftw.builders.fft(
            dummy, n=n_point, overwrite_input=True, threads=self.n_thread)
        return fft(self.vecs)

    def plot_spectra(self, spectra):
        '''
        Plot the power spectra of the tapers together with their average.

        spectra: array as returned by gen_spectra, one row per taper

        Closes all existing matplotlib figures before drawing, then shows the new one.
        '''
        vals = self._ensure_eigenvalues()
        n_point = spectra.shape[1]
        n_half = (n_point + 1) // 2
        power = np.absolute(spectra[:, :n_half]) ** 2  # positive half, converted to power
        frequency = np.arange(n_half) / n_point
        plt.close("all")
        _, ax = plt.subplots()
        for val, spectrum in zip(vals, power):
            ax.fill_between(frequency, spectrum, alpha=.5, label="{:.7f}".format(val))
        ax.fill_between(frequency, np.mean(power, axis=0), alpha=.5, label="average")
        ax.axvline(self.W, color='k')  # marks the half bandwidth W
        ax.legend()
        ax.set_yscale("log")
        ax.set_xlim([0, .5])
        ax.set_ylim(bottom=1e-6)
        ax.set_xlabel("frequency")
        ax.set_ylabel("power")
        ax.set_title(self._title())
        plt.tight_layout(pad=.5)  # pad must be passed by keyword
        plt.show()
if __name__ == "__main__":
    # Demo: eight tapers of length 32 with NW = 4; eigenvalues are requested
    # because both plots use them as legend labels.
    tapers = DPSS(N=32, NW=4, K=8, eigenvalue=True)
    tapers.plot_sequences()
    # Transform with a 1024-point FFT, then plot the result.
    transformed = tapers.gen_spectra(1024)
    tapers.plot_spectra(transformed)