-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtkXspecCorner.py
242 lines (205 loc) · 9.35 KB
/
tkXspecCorner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import argparse
import numpy as np
import pandas as pd
import arviz as az
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.widgets import CheckButtons, TextBox
from astropy.io import fits
import corner
import tkinter as tk
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
# Fonts: DejaVu Serif for regular text and for mathtext; text.usetex is off,
# so matplotlib renders math with its own mathtext engine rather than LaTeX.
plt.rcParams.update({
    "font.family": "DejaVu Serif",
    "mathtext.fontset": "dejavuserif",
    "text.usetex": False,
})
apptitle = 'tkXspecCorner v2023/08/08 by Federico Garcia'
# Plot scheme: thin lines and axes, inward ticks on all four sides, and
# minor ticks switched on at half the length of the major ticks.
plt.rcParams.update({
    'lines.linewidth': 1.5,
    'axes.linewidth': 0.8,
    'axes.axisbelow': False,
    'xtick.top': True,
    'xtick.bottom': True,
    'ytick.right': True,
    'ytick.left': True,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.width': 0.8,
    'ytick.major.width': 0.8,
    'xtick.minor.width': 0.8,
    'ytick.minor.width': 0.8,
    'xtick.minor.visible': True,
    'ytick.minor.visible': True,
    'xtick.major.size': 3,
    'ytick.major.size': 3,
    'xtick.minor.size': 1.5,
    'ytick.minor.size': 1.5,
})
def UpdateCornerPlot(selectedTitles, contours, showTitles, showXYlabels, selectedAltNames):
    '''Create and Update the CornerPlot based on the selectedTitles,
    and plotting options like Contours, Titles and Labels.

    Parameters
    ----------
    selectedTitles : sequence of str
        Entries ticked in the GUI. Only entries that are exact names of
        posterior variables in ``dataset`` are plotted; anything else (e.g.
        the 'Draw Contours' / 'Show Titles' / 'Show XY Labels' toggles) is
        ignored.
    contours : bool
        Passed to corner as ``plot_contours``.
    showTitles : bool
        Passed to corner as ``show_titles``.
    showXYlabels : bool
        When True the axes are labelled with the alternative names,
        otherwise axis labels are suppressed.
    selectedAltNames : sequence of str
        Display names, paired element-by-element with ``selectedTitles``.

    Reads the module-level ``figcorner``, ``dataset``, ``fontSize``,
    ``title_fmt``, ``bins`` and ``labelpad`` set up by the script entry point.
    If no posterior variable is selected, the figure is left blank.
    '''
    figcorner.clear()
    # Keep each plotted variable paired with its display name so labels and
    # titles cannot drift onto the wrong panel.
    posteriorVars = set(dataset.posterior.data_vars)
    pairs = [(title, alt) for title, alt in zip(selectedTitles, selectedAltNames)
             if title in posteriorVars]
    if pairs:
        varNames = [title for title, _ in pairs]
        altNames = [alt for _, alt in pairs]
        labels = altNames if showXYlabels else [None] * len(altNames)
        # Exact name matching (filter_vars=None): "like" does substring
        # matching, so e.g. '1. nH' would also select '11. nH'.
        corner.corner(dataset, var_names=varNames, filter_vars=None, fig=figcorner,
                      labels=labels, label_kwargs={"fontsize": fontSize},
                      titles=altNames, show_titles=showTitles, title_fmt=title_fmt,
                      title_kwargs={"fontsize": fontSize},
                      plot_datapoints=False, plot_density=True, plot_contours=contours,
                      smooth=True, quantiles=(0.16, 0.50, 0.84), use_math_text=True,
                      bins=bins, labelpad=labelpad,
                      hist_kwargs={"fill": True, "color": "gray"})
    figcorner.canvas.draw()
    return
def UpdateAll():
    '''Read the GUI state and redraw the corner plot.

    Copies every checkbox state into the module-level ``selected`` list and
    every text-entry value into the module-level ``AltNames`` array (both
    updated in place). It then calls UpdateCornerPlot with the ticked titles
    and display names. The last three entries of ``selected`` are used as the
    Draw Contours / Show Titles / Show XY Labels flags.
    '''
    for idx, boolVar in enumerate(selVariables):
        selected[idx] = boolVar.get()
    for idx, strVar in enumerate(textVariables):
        AltNames[idx] = strVar.get()
    tickedTitles = Titles[selected]
    tickedAltNames = AltNames[selected]
    drawContours, titlesOn, xyLabelsOn = selected[-3:]
    UpdateCornerPlot(tickedTitles, drawContours, titlesOn, xyLabelsOn, tickedAltNames)
    return
if __name__ == '__main__':
'''pyXspecCorner is a CornerPlotter for XSPEC MCMC Chains saved to FITS files'''
# Organize the Parser to get Chain file, burn-in and samples to be used.
parser = argparse.ArgumentParser(prog='pyXspecCorner',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description='Make interactive CornerPlots based on XSPEC MCMC Chain FITS files.')
parser.add_argument("chain", help="Path to XSPEC Chain FITS file", type=str, default='chain.fits')
parser.add_argument("--burn", help="Samples to Burn In", type=int, default=0, nargs='?')
parser.add_argument("--samples",help="Samples used in CornerPlot (-1 to use all)", type=int, default=-1, nargs='?')
parser.add_argument("--bins",help="Number of Bins used in CornerPlot", type=int, default=30, nargs='?')
parser.add_argument("--format",help="Numeric format of Titles and XYlabels in CornerPlot", type=str, default='.2f', nargs='?')
parser.add_argument("--labelpad",help="Fractional label padding for Titles and XYlabels", type=float, default=0.05, nargs='?')
parser.add_argument("--fontSize",help="Font Size for Titles, XYTicks and XYlabels", type=str, default='xx-small', nargs='?')
args = parser.parse_args()
# Use the parsed arguments to get the selected data from the Chain FITS file
chainName = args.chain
BurnIn = int(args.burn)
Samples = int(args.samples)
bins = int(args.bins)
title_fmt = args.format
labelpad = args.labelpad
fontSize = args.fontSize
# Set dynamic plotting styles
plt.rc('xtick',labelsize=fontSize)
plt.rc('ytick',labelsize=fontSize)
plt.rc('axes',labelsize=fontSize)
chain = fits.open(chainName)
nFields = int(chain[1].header['TFIELDS'])
ChainLength = int(chain[1].header['NAXIS2'])
if Samples<0:
Samples = ChainLength
idx = np.arange(min(BurnIn,0), ChainLength)
else:
idx = np.random.randint(low=min(BurnIn,0), high=ChainLength, size=Samples)
print()
print('===============================================')
print(' Loading Chain: {}'.format(chainName))
print(' Chain Length: {}'.format(ChainLength))
print(' Number of Fields: {}'.format(nFields))
print('===============================================')
print(' Burning in {} samples'.format(BurnIn))
print(' Using {} samples for plotting purposes'.format(Samples))
print('===============================================')
print()
# Create the DataFrame for the CornerPlot and fill it with the Data and Titles.
df = pd.DataFrame()
titles, ttypes = [], []
for i in range(nFields):
ttype = chain[1].header['TTYPE{}'.format(i+1)]
tform = chain[1].header['TFORM{}'.format(i+1)]
try:
tunit = chain[1].header['TUNIT{}'.format(i+1)]
except:
tunit = ''
try:
tt = ttype.split('__')
if len(tt) == 3:
tname, tmodel, tnum = tt
elif len(tt) == 2:
tname, tnum = tt
tmodel = ''
else:
tnum = str(int(tnum)+1)
tname = 'Chi-Squared'
tunit = ''
tmodel = ''
except:
tmodel = ''
tunit = ''
ttypes.append(ttype)
if tunit and tmodel:
title = '{} {}. {} [{}]'.format(tmodel,tnum,tname,tunit)
elif tmodel:
title = '{} {}. {}'.format(tmodel,tnum,tname)
elif tunit:
title = '{}. {} [{}]'.format(tnum,tname,tunit)
else:
title = '{}. {}'.format(tnum,tname)
titles.append(title)
df[title] = chain[1].data[ttype][idx]
df["chain"] = 0
df["draw"] = np.arange(len(df), dtype=int)
df = df.set_index(["chain", "draw"])
xdata = xr.Dataset.from_dataframe(df)
dataset = az.InferenceData(posterior=xdata)
func_dict = {
"median": lambda x: np.percentile(x, 50),
"5%": lambda x: np.percentile(x, 5),
"16%": lambda x: np.percentile(x, 16),
"84%": lambda x: np.percentile(x, 84),
"95%": lambda x: np.percentile(x, 95),
}
print('Summary statistics based on selected posterior samples:')
print()
print(az.summary(dataset, group='posterior', stat_funcs=func_dict, extend=False))
print()
print('Building interactive plot...')
print()
# Add some Plotting abilities
titles.append('Draw Contours')
contours = False
titles.append('Show Titles')
showTitles = True
titles.append('Show XY Labels')
showXYlabels = False
# Make the Titles and Pre-select Chi-square and Titles
Titles = pd.Series(titles)
AltNames = np.array(Titles)
selected = [False for Title in Titles]
selected[-4] = True
selected[-2] = True
selectedTitles = Titles[selected]
selectedAltNames = AltNames[selected]
# Create the two interactive figures and fill them: Buttons and CornerPlot
figcorner = plt.Figure(figsize=(6,6), dpi=140)
app = tk.Tk()
app.title(apptitle)
ParamTitle = tk.Label(app, text='Parameters', font="sans 12 bold")
ParamTitle.grid(row=0, column=1, columnspan=2, rowspan=1)
selVariables, chkButtons, textVariables, txtButtons = [], [], [], []
for i, Title in enumerate(Titles):
selVariables.append(tk.BooleanVar())
chkButton = tk.Checkbutton(app, text=Titles[i], var=selVariables[i])
chkButtons.append(chkButton)
chkButton.grid(row=i+1, column=1)
textVariables.append(tk.StringVar())
textVariables[i].set(Titles[i])
txtButton = tk.Entry(app, text=Titles[i], textvariable=textVariables[i])
txtButtons.append(txtButton)
txtButton.grid(row=i+1, column=2)
updateButton = tk.Button(app, text='Update Corner Plot', font="sans 10 bold", command=UpdateAll)
updateButton.grid(row=len(Titles)+1, column=1, columnspan=2, rowspan=1)
canvas = FigureCanvasTkAgg(figcorner, app)
canvas.get_tk_widget().grid(row=0,column=4,columnspan=10,rowspan=len(Titles)+1)
toolbar = NavigationToolbar2Tk(canvas, app, pack_toolbar=False)
toolbar.grid(row=len(Titles)+1,column=4,columnspan=10,rowspan=1)
UpdateCornerPlot(selectedTitles, contours, showTitles, showXYlabels, selectedAltNames)
app.mainloop()