-
Notifications
You must be signed in to change notification settings - Fork 1
/
view_coeffs.py
122 lines (92 loc) · 3.88 KB
/
view_coeffs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import pickle
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
class SEOBNRv4AmpPhase(Dataset):
    """Map-style dataset pairing parameter rows with coefficient rows.

    The pickle at ``filename`` must contain a 4-element sequence
    ``[lambda_values, coeffs, eim_basis, eim_indices]``:

    * ``lambda_values`` is exposed as ``self.x`` (inputs),
    * ``coeffs`` is exposed as ``self.y`` (targets),
    * ``eim_basis`` is kept as ``self.basis``,
    * ``eim_indices`` only contributes its first dimension, stored as
      ``self.n_coeffs``.

    ``x`` and ``y`` are indexed as ``[row, :]``, so they are assumed to be
    2-D array-likes (e.g. NumPy arrays) with matching row counts — TODO
    confirm against the files that are loaded.

    NOTE: ``pickle.load`` can execute arbitrary code; only open trusted files.
    """

    def __init__(self, filename):
        with open(filename, 'rb') as handle:
            payload = pickle.load(handle)
        lambdas, coefficients, basis, indices = payload
        self.x = lambdas
        self.y = coefficients
        self.basis = basis
        self.n_coeffs = indices.shape[0]

    def __getitem__(self, item):
        """Return the ``(x_row, y_row)`` pair at position ``item``."""
        features = self.x[item, :]
        targets = self.y[item, :]
        return features, targets

    def __len__(self):
        """Number of samples, taken from the first dimension of ``x``."""
        rows = self.x.shape[0]
        return rows
def _plot_coeff_panels(dataset, coeff_indices, title, outfile, cmap, stride=30):
    """Draw one 3-D scatter panel per coefficient index and save the figure.

    Each panel puts ``dataset.x[:, 1]`` (labelled chi_1) and ``dataset.x[:, 2]``
    (labelled chi_2) on the horizontal axes and column ``coeff`` of
    ``dataset.y`` on the vertical axis, coloured by ``dataset.x[:, 0]``
    (colorbar labelled q). Only every ``stride``-th sample is drawn to keep
    the scatter readable.

    Args:
        dataset: object exposing 2-D array-like ``x`` and ``y`` attributes,
            e.g. ``SEOBNRv4AmpPhase``.
        coeff_indices: column indices of ``dataset.y`` to plot, one panel each.
        title: figure super-title.
        outfile: path the figure is saved to (with ``bbox_inches='tight'``).
        cmap: colormap used for the ``q`` colouring.
        stride: row subsampling step applied to ``x`` and ``y``.

    Returns:
        ``(fig, ax, sc, cbaxes, cbar)``: the figure, the last panel's axes and
        scatter artist, and the colorbar axes and colorbar.

    Raises:
        ValueError: if ``coeff_indices`` is empty (no scatter to attach the
            colorbar to).
    """
    coeff_indices = list(coeff_indices)
    if not coeff_indices:
        raise ValueError("coeff_indices must contain at least one index")

    x = dataset.x[::stride]
    y = dataset.y[::stride]
    n_panels = len(coeff_indices)

    fig = plt.figure(figsize=(14, 4))
    for position, coeff in enumerate(coeff_indices, start=1):
        ax = fig.add_subplot(1, n_panels, position, projection='3d')
        sc = ax.scatter(x[:, 1], x[:, 2], y[:, coeff], c=x[:, 0], cmap=cmap)
        ax.set_ylabel(r'$\chi_2$')
        ax.set_xlabel(r'$\chi_1$')

    # Dedicated axes on the right so the colorbar does not steal space from
    # the last panel; all panels share the same cmap, so one colorbar suffices.
    cbaxes = fig.add_axes([0.92, 0.1, 0.03, 0.8])
    cbar = fig.colorbar(sc, cax=cbaxes)
    cbar.set_label(r'$q$', rotation=0)
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.suptitle(title)
    fig.savefig(outfile, bbox_inches='tight')
    return fig, ax, sc, cbaxes, cbar


# The amplitude figure was produced with the same layout:
# train_dataset_amp = SEOBNRv4AmpPhase(filename='q1to8_s0.99_both/amp_rel_sur/tol_1e-10.pkl')
# _plot_coeff_panels(train_dataset_amp, (15, 16, 17), 'Amplitude coefficients',
#                    'amp_coeffs_15-17.pdf', cmap)
train_dataset_phi = SEOBNRv4AmpPhase(filename='q1to8_s0.99_both/phi_rel_sur/tol_1e-10.pkl')
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap works on both older and current releases.
cmap = plt.get_cmap('plasma')

# PHASE: coefficients 5-7.
fig, ax, sc, cbaxes, cbar = _plot_coeff_panels(
    train_dataset_phi, (5, 6, 7), 'Phase coefficients', 'phi_coeffs_5-7.pdf', cmap)
plt.show()