Skip to content

Commit 93ef8f0

Browse files
author
harnessa
committedMar 1, 2022
up
1 parent 26a6116 commit 93ef8f0

File tree

5 files changed

+226
-2
lines changed

5 files changed

+226
-2
lines changed
 

‎delfi/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.npz

‎delfi/build_gauss_testset.py

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import numpy as np
2+
import h5py
3+
import time
4+
import os
5+
import imp
6+
imp_dir = '/home/aharness/repos/starshade-xy/quadrature_code'
7+
diffraq = imp.load_source("diffraq", os.path.join(imp_dir, 'diffraq', "__init__.py"))
8+
new_noise_maker = imp.load_source("new_noise_maker", os.path.join(imp_dir, "new_noise_maker.py"))
9+
from new_noise_maker import Noise_Maker
10+
11+
#Save directory
12+
session = 'n10k_sigd25'
13+
14+
#Number of samples
15+
num_imgs = 10000
16+
17+
#Gaussian distribution
18+
sig_width = 0.25
19+
20+
#SNR range
21+
snr_range = [5, 25]
22+
23+
#User options
24+
apod_name = 'm12p8'
25+
with_spiders = True
26+
wave = 403e-9
27+
num_tel_pts = 96
28+
#Telescope sizes in Lab and Space coordinates [m] (sets scaling factor)
29+
Dtel_lab = 2.201472e-3
30+
Dtel_space = 2.4
31+
#Random number generator seed
32+
seed = 88
33+
34+
############################
35+
36+
#Get random number generatore
37+
rng = np.random.default_rng(seed)
38+
39+
#Lab to space scaling
40+
lab2space = Dtel_space / Dtel_lab
41+
space2lab = 1/lab2space
42+
43+
############################
44+
45+
#Create directory
46+
save_dir = './Sim_Data'
47+
48+
#Load instance of noise maker
49+
noise_params = {'count_rate': 7, 'rng': rng, 'snr_range': snr_range,
50+
'num_tel_pts': num_tel_pts}
51+
noiser = Noise_Maker(noise_params)
52+
53+
#Specify simulation parameters
54+
params = {
55+
### Lab ###
56+
'wave': wave, #Wavelength of light [m]
57+
58+
### Telescope ###
59+
'tel_diameter': Dtel_lab, #Telescope aperture diameter [m]
60+
'num_tel_pts': num_tel_pts, #Size of grid to calculate over pupil
61+
62+
### Starshade ###
63+
#will specify apod_name after circle is run
64+
'num_petals': 12, #Number of starshade petals
65+
66+
### Saving ###
67+
'do_save': False, #Don't save data
68+
'verbose': False, #Silence output
69+
'xtras_dir': f'{imp_dir}/xtras',
70+
}
71+
72+
#Run unblocked image first (lab calibration)
73+
params['apod_name'] = 'circle'
74+
params['circle_rad'] = 25.086e-3
75+
sim = diffraq.Simulator(params)
76+
sim.setup_sim()
77+
cal_img = np.abs(sim.calculate_diffraction())**2
78+
79+
#Set to noise maker
80+
noiser.set_suppression_norm(cal_img)
81+
82+
#Save image
83+
cal_file = os.path.join(save_dir, 'calibration')
84+
85+
#New simulator for starshade images
86+
params['with_spiders'] = with_spiders
87+
params['apod_name'] = apod_name
88+
sim = diffraq.Simulator(params)
89+
sim.setup_sim()
90+
91+
#Initiate containers
92+
images = np.empty((num_imgs, sim.num_pts, sim.num_pts))
93+
positions = np.empty((num_imgs, 2))
94+
95+
tik = time.perf_counter()
96+
97+
#Loop and calculate image
98+
for i in range(num_imgs):
99+
100+
if i % (num_imgs // 10) == 0:
101+
print(f'Running step # {i} / {num_imgs} ({time.perf_counter()-tik:.0f} s)')
102+
103+
#Get random position from gaussian
104+
nx, ny = rng.multivariate_normal([0,0], np.eye(2)*(sig_width*space2lab)**2)
105+
106+
#Set shift of telescope
107+
sim.tel_shift = [nx, ny]
108+
109+
#Get random snr
110+
snr = rng.uniform(snr_range[0], snr_range[1])
111+
112+
#Get diffraction and convert to intensity
113+
img = np.abs(sim.calculate_diffraction())**2
114+
115+
#Add noise
116+
img = noiser.add_noise(img, snr)
117+
118+
#Store
119+
images[i] = img
120+
positions[i] = [nx,ny]
121+
122+
tok = time.perf_counter()
123+
124+
#Save
125+
with h5py.File(f'{save_dir}/{session}_data.h5', 'w') as f:
126+
f.create_dataset('num_tel_pts', data=num_tel_pts)
127+
f.create_dataset('sig_width', data=sig_width)
128+
f.create_dataset('lab2space', data=lab2space)
129+
f.create_dataset('seed', data=seed)
130+
f.create_dataset('snr', data=snr)
131+
f.create_dataset('images', data=images)
132+
f.create_dataset('positions', data=positions)
133+
134+
print(f'\nRan {num_imgs} images in {tok-tik:.0f} s\n')

‎delfi/fit_gmm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# run 5 GMMs and average them
1616
R = 5
1717
gmms = [ pygmmis.GMM(K=50, D=4) for _ in range(R) ]
18-
for r in range(5):
18+
for r in range(R):
1919
pygmmis.fit(gmms[r], data, init_method='minmax', w=1e-2, cutoff=5, split_n_merge=3)
2020
gmm = pygmmis.stack(gmms, np.ones(R))
2121

‎delfi/run_delfi.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def estimate_theta(gmm_theta, which="mean"):
7777
assert len(sys.argv) == 4, "usage: run_delfi.py <GMM model file> <CNN x> <CNN y>"
7878

7979
# open GMM
80-
gmm = pygmmis.GMM.from_file(sys.argv[1])
80+
gmm = pygmmis.GMM()
81+
gmm.load(sys.argv[1])
8182
# split into precomputable sub-mixtures
8283
gmm_theta, gmm_t, gmm_x = split_gmm(gmm)
8384

‎delfi/test_gauss_testset.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import numpy as np
2+
import h5py
3+
import os
4+
import torch
5+
import torch.nn.functional as F
6+
from torchvision import transforms
7+
import h5py
8+
import atexit
9+
import time
10+
import imp
11+
imp_dir = '/home/aharness/repos/starshade-xy'
12+
cnn = imp.load_source("cnn", os.path.join(imp_dir, "cnn_andrew.py"))
13+
from cnn import CNN
14+
15+
do_save = [False, True][1]
16+
17+
data_run = 'n10k_sigd25'
18+
19+
model_name = 'Newest_andrew'
20+
21+
normalization = 0.03
22+
#Telescope sizes in Lab and Space coordinates [m] (sets scaling factor)
23+
Dtel_lab = 2.201472e-3
24+
Dtel_space = 2.4
25+
lab2space = Dtel_space / Dtel_lab
26+
27+
#Directories
28+
model_dir = '/home/aharness/repos/starshade-xy/models'
29+
save_dir = 'Test_Results'
30+
31+
#######################
32+
33+
#Open test data file
34+
test_loader = h5py.File(f'./Sim_Data/{data_run}_data.h5', 'r')
35+
atexit.register(test_loader.close)
36+
37+
#Get images, amplitudes, and positions
38+
images = test_loader['images']
39+
positions = test_loader['positions']
40+
41+
#Load model
42+
model = CNN(images.shape[-1])
43+
mod_file = os.path.join(model_dir, model_name + '.pt')
44+
model.load_state_dict(torch.load(mod_file))
45+
model.eval()
46+
47+
#Transform
48+
transform = transforms.Compose([transforms.ToTensor()])
49+
50+
tik = time.perf_counter()
51+
print(f'Testing {model_name} model with {images.shape[0]} images...')
52+
53+
#Loop through images and get prediction position
54+
predicted_position = np.zeros((0,2))
55+
with torch.no_grad():
56+
57+
for img, pos in zip(images, positions):
58+
59+
#Normalize image by
60+
img /= normalization
61+
62+
#Change datatype
63+
img = img.astype('float32')
64+
65+
#Transform image
66+
img = transform(img)
67+
img = torch.unsqueeze(img, 0)
68+
69+
#Get solved position
70+
output = model(img)
71+
output = output.cpu().detach().numpy().squeeze().astype(float)
72+
73+
#Store
74+
predicted_position = np.concatenate((predicted_position, [output]))
75+
76+
tok = time.perf_counter()
77+
print(f'Done! in {tok-tik:.1f} s')
78+
79+
#Save results
80+
if do_save:
81+
82+
# format: x,y,x',y'
83+
data = np.hstack((positions[()]*lab2space, predicted_position))
84+
fname = os.path.join(save_dir, f'{data_run}__{model_name}')
85+
np.save(fname, data)
86+
87+
else:
88+
breakpoint()

0 commit comments

Comments
 (0)
Please sign in to comment.