# -*- coding: utf-8 -*-
# Copyright 2019 the HERA Project
# Licensed under the MIT License
import numpy as np
from copy import deepcopy
import argparse
import os
import linsolve
from itertools import chain
from . import utils
from .noise import predict_noise_variance_from_autos, infer_dt
from .datacontainer import DataContainer, RedDataContainer
from .utils import split_pol, conj_pol, split_bl, reverse_bl, join_bl, join_pol, comply_pol, per_antenna_modified_z_scores
from .io import HERAData, HERACal, write_cal, save_redcal_meta
from .apply_cal import calibrate_in_place
SEC_PER_DAY = 86400.
IDEALIZED_BL_TOL = 1e-8 # bl_error_tol for redcal.get_reds when using antenna positions calculated from reds
def get_pos_reds(antpos, bl_error_tol=1.0, include_autos=False):
""" Figure out and return list of lists of redundant baseline pairs. Ordered by length. All baselines
in a group have the same orientation with a preference for positive b_y and, when b_y==0, positive
b_x where b((i,j)) = pos(j) - pos(i).
Args:
antpos: dictionary of antenna positions in the form {ant_index: np.array([x,y,z])}. 1D and 2D also OK.
bl_error_tol: the largest allowable difference between baselines in a redundant group
(in the same units as antpos). Normally, this is up to 4x the largest antenna position error.
include_autos: bool, optional
if True, include autos in the list of pos_reds. default is False
Returns:
reds: list (sorted by baseline length) of lists of redundant tuples of antenna indices (no polarizations),
sorted by index, with the first index of the first baseline being the lowest in the group.
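Example (an illustrative sketch, not from the original docstring; assumes a 3-element
east-west array with 10 m spacing, so the shorter 10 m group comes first):
>>> antpos = {0: np.array([0., 0., 0.]),
...           1: np.array([10., 0., 0.]),
...           2: np.array([20., 0., 0.])}
>>> get_pos_reds(antpos)
[[(0, 1), (1, 2)], [(0, 2)]]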
"""
keys = list(antpos.keys())
reds = {}
assert np.all([len(pos) <= 3 for pos in antpos.values()]), 'Get_pos_reds only works in up to 3 dimensions.'
ap = {ant: np.pad(pos, (0, 3 - len(pos)), mode='constant') for ant, pos in antpos.items()} # increase dimensionality
array_is_flat = np.all(np.abs(np.array(list(ap.values()))[:, 2] - np.mean(list(ap.values()), axis=0)[2]) < bl_error_tol / 4.0)
p_or_m = (0, -1, 1)
if array_is_flat:
epsilons = [[dx, dy, 0] for dx in p_or_m for dy in p_or_m]
else:
epsilons = [[dx, dy, dz] for dx in p_or_m for dy in p_or_m for dz in p_or_m]
def check_neighbors(delta): # Check to make sure reds doesn't have the key plus or minus rounding error
for epsilon in epsilons:
newKey = (delta[0] + epsilon[0], delta[1] + epsilon[1], delta[2] + epsilon[2])
if newKey in reds:
return newKey
return
for i, ant1 in enumerate(keys):
if include_autos:
start_ind = i
else:
start_ind = i + 1
for ant2 in keys[start_ind:]:
bl_pair = (ant1, ant2)
delta = tuple(np.round(1.0 * (np.array(ap[ant2]) - np.array(ap[ant1])) / bl_error_tol).astype(int))
new_key = check_neighbors(delta)
if new_key is None: # forward baseline has no matches
new_key = check_neighbors(tuple([-d for d in delta]))
if new_key is not None: # reverse baseline does have a match
bl_pair = (ant2, ant1)
if new_key is not None: # either the forward or reverse baseline has a match
reds[new_key].append(bl_pair)
else: # this baseline is entirely new
if delta[0] <= 0 or (delta[0] == 0 and delta[1] <= 0) or (delta[0] == 0 and delta[1] == 0 and delta[2] <= 0):
delta = tuple([-d for d in delta])
bl_pair = (ant2, ant1)
reds[delta] = [bl_pair]
# sort reds by length and each red to make sure the first antenna of the first bl in each group is the lowest antenna number
orderedDeltas = [delta for (length, delta) in sorted(zip([np.linalg.norm(delta) for delta in reds.keys()], reds.keys()))]
return [sorted(reds[delta]) if sorted(reds[delta])[0][0] == np.min(reds[delta])
else sorted([reverse_bl(bl) for bl in reds[delta]]) for delta in orderedDeltas]
def add_pol_reds(reds, pols=['nn'], pol_mode='1pol'):
""" Takes positonal reds (antenna indices only, no polarizations) and converts them
into baseline tuples with polarization, depending on pols and pol_mode specified.
Args:
reds: list of list of antenna index tuples considered redundant
pols: a list of polarizations e.g. ['nn', 'ne', 'en', 'ee']
pol_mode: polarization mode of calibration
'1pol': 1 antpol and 1 vispol (e.g. 'Jnn' and 'nn'). Default.
'2pol': 2 antpols, no cross-vispols (e.g. 'Jnn','Jee' and 'nn','ee')
'4pol': 2 antpols, 4 vispols (e.g. 'Jnn','Jee' and 'nn','ne','en','ee')
'4pol_minV': 2 antpols, 4 vispols in data but assuming V_ne = V_en in model
Returns:
reds: list of lists of redundant baseline tuples, e.g. (ind1,ind2,pol)
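Example (an illustrative sketch; in '2pol' mode each polarization gets its own copy of each group):
>>> reds = [[(0, 1), (1, 2)], [(0, 2)]]
>>> add_pol_reds(reds, pols=['nn', 'ee'], pol_mode='2pol')
[[(0, 1, 'nn'), (1, 2, 'nn')], [(0, 2, 'nn')], [(0, 1, 'ee'), (1, 2, 'ee')], [(0, 2, 'ee')]]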
"""
# pre-process to ensure pols complies w/ hera_cal polarization convention
pols = [comply_pol(p) for p in pols]
redsWithPols, didBothCrossPolsForMinV = [], False
for pol in pols:
if pol_mode != '4pol_minV' or pol[0] == pol[1]:
redsWithPols += [[bl + (pol,) for bl in bls] for bls in reds]
elif pol_mode == '4pol_minV' and not didBothCrossPolsForMinV:
# Combine together e.g. 'ne' and 'en' visibilities as redundant
redsWithPols += [([bl + (pol,) for bl in bls]
+ [bl + (conj_pol(pol),) for bl in bls]) for bls in reds]
didBothCrossPolsForMinV = True
return redsWithPols
def get_reds(antpos, pols=['nn'], pol_mode='1pol', bl_error_tol=1.0, include_autos=False):
""" Combines redcal.get_pos_reds() and redcal.add_pol_reds(). See their documentation.
Args:
antpos: dictionary of antenna positions in the form {ant_index: np.array([x,y,z])}.
pols: a list of polarizations e.g. ['nn', 'ne', 'en', 'ee']
pol_mode: polarization mode of calibration
'1pol': 1 antpol and 1 vispol (e.g. 'Jnn' and 'nn'). Default.
'2pol': 2 antpols, no cross-vispols (e.g. 'Jnn','Jee' and 'nn','ee')
'4pol': 2 antpols, 4 vispols (e.g. 'Jnn','Jee' and 'nn','ne','en','ee')
'4pol_minV': 2 antpols, 4 vispols in data but assuming V_ne = V_en in model
bl_error_tol: the largest allowable difference between baselines in a redundant group
(in the same units as antpos). Normally, this is up to 4x the largest antenna position error.
include_autos: bool, optional
if True, include the autocorrelation redundant group.
Default is False.
Returns:
reds: list (sorted by baseline length) of lists of redundant baseline tuples, e.g. (ind1,ind2,pol).
Each interior list is sorted so that the lowest index is first in the first baseline.
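Example (an illustrative sketch; the same 3-element east-west array as in get_pos_reds):
>>> antpos = {0: np.array([0., 0., 0.]), 1: np.array([10., 0., 0.]), 2: np.array([20., 0., 0.])}
>>> get_reds(antpos, pols=['ee'], pol_mode='1pol')
[[(0, 1, 'ee'), (1, 2, 'ee')], [(0, 2, 'ee')]]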
"""
pos_reds = get_pos_reds(antpos, bl_error_tol=bl_error_tol, include_autos=include_autos)
return add_pol_reds(pos_reds, pols=pols, pol_mode=pol_mode)
def filter_reds(reds, bls=None, ex_bls=None, ants=None, ex_ants=None, ubls=None, ex_ubls=None,
pols=None, ex_pols=None, antpos=None, min_bl_cut=None, max_bl_cut=None,
max_dims=None, min_dim_size=1):
'''
Filter redundancies to include/exclude the specified bls, antennas, unique bl groups and polarizations.
Also allows filtering reds by removing antennas so that the number of generalized tip/tilt degeneracies
is no more than max_dims. Arguments are evaluated, in order of increasing precedence: (pols, ex_pols,
ubls, ex_ubls, bls, ex_bls, ants, ex_ants, min_bl_cut, max_bl_cut, max_dims).
Args:
reds: list of lists of redundant (i,j,pol) baseline tuples, e.g. the output of get_reds().
Not modified in place.
bls (optional): baselines to include. Baselines of the form (i,j,pol) include that specific
visibility. Baselines of the form (i,j) are broadcast across all polarizations present in reds.
ex_bls (optional): same as bls, but excludes baselines.
ants (optional): antennas to include. Only baselines where both antenna indices are in ants
are included. Antennas of the form (i,pol) include that antenna/pol. Antennas of the form i are
broadcast across all polarizations present in reds.
ex_ants (optional): same as ants, but excludes antennas.
ubls (optional): redundant (unique baseline) groups to include. Each baseline in ubls is taken to
represent the redundant group containing it. Baselines of the form (i,j) are broadcast across all
polarizations, otherwise (i,j,pol) selects a specific redundant group.
ex_ubls (optional): same as ubls, but excludes groups.
pols (optional): polarizations to include in reds. e.g. ['nn','ee','ne','en']. Default includes all
polarizations in reds.
ex_pols (optional): same as pols, but excludes polarizations.
antpos: dictionary of antenna positions in the form {ant_index: np.array([x,y,z])}. 1D and 2D also OK.
min_bl_cut: cut redundant groups with average baseline lengths shorter than this. Same units as antpos
which must be specified.
max_bl_cut: cut redundant groups with average baselines lengths longer than this. Same units as antpos
which must be specified.
max_dims: maximum number of dimensions required to specify antenna positions (up to some arbitrary shear).
This is equivalent to the number of generalized tip/tilt phase degeneracies of redcal that are fixed
with remove_degen() and must be later abscaled. 2 is a classically "redundantly calibratable" planar
array. More than 2 usually arises with subarrays of redundant baselines. None means no filtering.
min_dim_size: minimum number of antennas allowed with non-zero positions in a given dimension. This
allows filtering out antennas where only a few are responsible for adding a dimension. Ignored
if max_dims is None. Default 1 means no further filtering based on the number of antennas in that dim.
Returns:
reds: list of lists of redundant baselines in the same form as input reds.
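Example (an illustrative sketch; excluding antenna 2 removes every baseline it participates
in and drops any group left empty):
>>> reds = [[(0, 1, 'nn'), (1, 2, 'nn')], [(0, 2, 'nn')]]
>>> filter_reds(reds, ex_ants=[2])
[[(0, 1, 'nn')]]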
'''
# pre-processing step to ensure that reds complies with hera_cal polarization conventions
reds = [[(i, j, comply_pol(p)) for (i, j, p) in gp] for gp in reds]
if pols is None:  # if no pols are provided, deduce them from the reds
pols = set(gp[0][2] for gp in reds)
if ex_pols:
pols = set(p for p in pols if p not in ex_pols)
reds = [gp for gp in reds if gp[0][2] in pols]
def expand_bls(gp):
gp3 = [(g[0], g[1], p) for g in gp if len(g) == 2 for p in pols]
return set(gp3 + [g for g in gp if len(g) == 3])
antpols = set(sum([list(split_pol(p)) for p in pols], []))
def expand_ants(gp):
gp2 = [(g, p) for g in gp if not hasattr(g, '__len__') for p in antpols]
return set(gp2 + [g for g in gp if hasattr(g, '__len__')])
def split_bls(bls):
return set(split_bl(bl) for bl in bls)
if ubls or ex_ubls:
bl2gp = {}
for i, gp in enumerate(reds):
for key in gp:
bl2gp[key] = bl2gp[reverse_bl(key)] = i
if ubls:
ubls = expand_bls(ubls)
ubls = set(bl2gp[key] for key in ubls if key in bl2gp)
else:
ubls = set(range(len(reds)))
if ex_ubls:
ex_ubls = expand_bls(ex_ubls)
ex_ubls = set(bl2gp[bl] for bl in ex_ubls if bl in bl2gp)
else:
ex_ubls = set()
reds = [gp for i, gp in enumerate(reds) if i in ubls and i not in ex_ubls]
if bls is not None:
bls = expand_bls(bls)
else: # default to set of all baselines
bls = set(key for gp in reds for key in gp)
if ex_bls:
ex_bls = expand_bls(ex_bls)
ex_bls |= set(reverse_bl(k) for k in ex_bls) # put in reverse baselines
bls = set(k for k in bls if k not in ex_bls)
if ants:
ants = expand_ants(ants)
bls = set(join_bl(i, j) for i, j in split_bls(bls) if i in ants and j in ants)
if ex_ants:
ex_ants = expand_ants(ex_ants)
bls = set(join_bl(i, j) for i, j in split_bls(bls) if i not in ex_ants and j not in ex_ants)
bls |= set(reverse_bl(k) for k in bls) # put in reverse baselines
reds = [[key for key in gp if key in bls] for gp in reds]
reds = [gp for gp in reds if len(gp) > 0]
if min_bl_cut is not None or max_bl_cut is not None:
assert antpos is not None, 'antpos must be passed in if min_bl_cut or max_bl_cut is specified.'
lengths = [np.mean([np.linalg.norm(antpos[bl[1]] - antpos[bl[0]]) for bl in gp]) for gp in reds]
reds = [gp for gp, l in zip(reds, lengths) if ((min_bl_cut is None or l > min_bl_cut)
and (max_bl_cut is None or l < max_bl_cut))]
if max_dims is not None:
while True:
# Compute idealized antenna positions from redundancies. Given the reds (a list of list of
# redundant baselines), these positions will be coordinates in a vector space that reproduce
# the ideal antenna positions with a set of unknown basis vectors. The dimensionality of
# idealized_antpos is determined in reds_to_antpos by first assigning each antenna its own
# dimension and then inferring how many of those are simply linear combinations of others
# using the redundancies. The number of dimensions is equivalent to the number of generalized
# tip/tilt degeneracies of redundant calibration.
idealized_antpos = reds_to_antpos(reds, tol=IDEALIZED_BL_TOL)
ia_array = np.array(list(idealized_antpos.values()))
# if we've removed all antennas, break
if len(ia_array) == 0:
break
# if we're down to 1 dimension, the mode finding below won't work. Just check Nants >= min_dim_size.
if len(ia_array[0]) <= 1:
if len(ia_array) >= min_dim_size:
break
# Find dimension with the most common mode idealized coordinate value. This is supposed to look
# for outlier antennas off the redundant grid or small sub-arrays that cannot be redundantly
# calibrated without adding more degeneracies than desired.
mode_count, mode_value, mode_dim = 0, 0, 0
for dim, coords in enumerate(np.array(list(idealized_antpos.values())).T):
rounded_coords = coords.round(decimals=int(np.floor(-np.log10(IDEALIZED_BL_TOL))))
unique, counts = np.unique(rounded_coords, return_counts=True)
if np.max(counts) > mode_count:
mode_count = np.max(counts)
mode_value = unique[counts == mode_count][0]
mode_dim = dim
# Cut all antennas not part of that mode to reduce the dimensionality of idealized_antpos
new_ex_ants = [ant for ant in idealized_antpos if
np.abs(idealized_antpos[ant][mode_dim] - mode_value) > IDEALIZED_BL_TOL]
# If we're down to the requested number of dimensions and if the next filtering would
# eliminate more antennas than min_dim_size, then break instead of filtering.
if len(ia_array[0]) <= max_dims:
if (len(new_ex_ants) >= min_dim_size):
break
reds = filter_reds(reds, ex_ants=new_ex_ants)
return reds
def combine_reds(reds1, reds2, unfiltered_reds=None):
'''Combine the groups in two separate lists of redundancies into one which
does not contain repeats.
Arguments:
reds1: list of lists of redundant baseline tuples to combine
reds2: another list of lists of redundant baseline tuples to combine
unfiltered_reds: optional list of list of redundant baselines. Used to combine
non-overlapping but redundant groups to get the most accurate answers.
Returns:
combined_reds: list of list of redundant baselines, combining reds1 and reds2
as much as possible
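Example (an illustrative sketch; without unfiltered_reds, groups sharing a baseline are unioned):
>>> reds1 = [[(0, 1, 'nn'), (1, 2, 'nn')]]
>>> reds2 = [[(1, 2, 'nn'), (2, 3, 'nn')]]
>>> combine_reds(reds1, reds2)
[[(0, 1, 'nn'), (1, 2, 'nn'), (2, 3, 'nn')]]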
'''
if unfiltered_reds is not None:
bls_to_use = set([bl for reds in [reds1, reds2] for red in reds for bl in red])
combined_reds = filter_reds(unfiltered_reds, bls=bls_to_use)
else:
# if unfiltered_reds is not provided, try to combine the groups as much as possible.
# N.B. this can still give wrong answers if there are baselines redundant with each
# other but unique to reds1 and reds2 respectively
reds1_sets = [set(red) for red in reds1]
reds1_map = {bl: n for n, red1_set in enumerate(reds1_sets) for bl in red1_set}
for red2 in reds2:
# figure out if any baseline in this group corresponds to a baseline in reds1
matched_group = None
for bl in red2:
if bl in reds1_map:
matched_group = reds1_map[bl]
if matched_group is not None:
# if there's a match, take the union of the two groups
reds1_sets[matched_group] |= set(red2)
else:
# otherwise, make a new group
reds1_sets.append(set(red2))
combined_reds = [list(red) for red in reds1_sets]
# sort result in a useful way
combined_reds = [sorted(red, key=lambda x: x[0]) for red in sorted(combined_reds, key=len, reverse=True)]
return combined_reds
def reds_to_antpos(reds, tol=1e-10):
'''Computes a set of antenna positions consistent with the given redundancies.
Useful for projecting out phase slope degeneracies, see https://arxiv.org/abs/1712.07212
Arguments:
reds: list of lists of redundant baseline tuples, either (i,j,pol) or (i,j)
tol: level for two vectors to be considered equal (enabling dimensionality reduction)
Returns:
antpos: dictionary of antenna positions in the form {ant_index: np.ndarray}.
These positions may differ from the true positions of antennas by an arbitrary
linear transformation. The dimensionality of the positions will be the minimum
necessary to describe all redundancies (non-redundancy introduces extra dimensions),
though the most used dimensions will come first.
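Example (an illustrative sketch; for a redundant east-west chain the recovered positions
are 1D and evenly spaced, up to an arbitrary linear transformation):
>>> reds = [[(0, 1), (1, 2)], [(0, 2)]]
>>> antpos = reds_to_antpos(reds)
>>> np.allclose(antpos[2] - antpos[1], antpos[1] - antpos[0])
True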
'''
ants = set([ant for red in reds for bl in red for ant in bl[:2]])
# start with all antennas (except the first) having their own dimension, then reduce the dimensionality
antpos = {ant: np.array([1. if d + 1 == i else 0. for d in range(len(ants) - 1)])
for i, ant in enumerate(ants)}
for red in reds:
for bl in red:
# look for vectors in the higher dimensional space that are equal to 0
delta = (antpos[bl[1]] - antpos[bl[0]]) - (antpos[red[0][1]] - antpos[red[0][0]])
if np.linalg.norm(delta) > tol: # this baseline can help us reduce the dimensionality
dim_to_elim = np.max(np.arange(len(delta))[np.abs(delta) > tol])
antpos = {ant: np.delete(pos - pos[dim_to_elim] / delta[dim_to_elim] * delta, dim_to_elim)
for ant, pos in antpos.items()}
# remove any all-zero dimensions
dim_to_elim = np.argwhere(np.sum(np.abs(list(antpos.values())), axis=0) == 0).flatten()
antpos = {ant: np.delete(pos, dim_to_elim) for ant, pos in antpos.items()}
# sort dims so that most-used dimensions come first
dim_usage = np.sum(np.abs(list(antpos.values())) > tol, axis=0)
antpos = {ant: pos[np.argsort(dim_usage)[::-1]] for ant, pos in antpos.items()}
return antpos
def make_sol_finite(sol):
'''Replaces nans and infs in solutions, which are usually the result of visibilities that are
identically equal to 0. Modifies sol (which is a dictionary with gains and visibilities) in place,
replacing visibilities with 0.0s and gains with 1.0s'''
for k, v in sol.items():
not_finite = ~np.isfinite(v)
if len(k) == 3: # visibilities
sol[k][not_finite] = np.zeros_like(v[not_finite])
elif len(k) == 2: # gains
sol[k][not_finite] = np.ones_like(v[not_finite])
def remove_degen_gains(reds, gains, degen_gains=None, mode='phase', pol_mode='1pol'):
""" Removes degeneracies from gains (or replaces them with those in gains). This
function is nominally intended for use with firstcal, which returns (phase/delay) solutions
for antennas only.
Args:
reds: list of lists of redundant baseline tuples, e.g. (ind1, ind2, pol).
gains: dictionary that contains gain solutions in the {(index,antpol): np.array} format.
degen_gains: Optional dictionary in the same format as gains. Gain amplitudes and phases
in degen_gains replace the values of gains in the degenerate subspace of redcal. If
left as None, average gain amplitudes will be 1 and average phase terms will be 0.
For logcal/lincal/omnical, putting firstcal solutions in here can help avoid structure
associated with phase-wrapping issues.
mode: 'phase' or 'complex', indicating whether the gains are passed as phases (e.g. delay
or phi in e^(i*phi)), or as the complex number itself. If 'phase', only phase degeneracies are
removed. If 'complex', both phase and amplitude degeneracies are removed.
pol_mode: polarization mode of redundancies. Can be '1pol', '2pol', '4pol', or '4pol_minV'.
Returns:
new_gains: gains with degeneracy removal/replacement performed
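Example (an illustrative sketch; a phase offset common to all antennas is purely
degenerate, so removing degeneracies zeroes it out):
>>> antpos = {0: np.array([0., 0., 0.]), 1: np.array([10., 0., 0.]), 2: np.array([20., 0., 0.])}
>>> reds = get_reds(antpos, pols=['nn'])
>>> gains = {(a, 'Jnn'): np.full((1, 1), np.exp(0.1j)) for a in range(3)}
>>> new_gains = remove_degen_gains(reds, gains, mode='complex')
>>> np.allclose([np.angle(g) for g in new_gains.values()], 0)
True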
"""
# Check supported pol modes
assert pol_mode in ['1pol', '2pol', '4pol', '4pol_minV'], f'Unrecognized pol_mode: {pol_mode}'
assert mode in ('phase', 'complex'), 'Unrecognized mode: %s' % mode
ants = list(set(ant for gp in reds for bl in gp for ant in split_bl(bl) if ant in gains))
gainPols = np.array([ant[1] for ant in ants]) # gainPols is list of antpols, one per antenna
antpols = list(set(gainPols))
# if mode is 2pol, run as two 1pol remove degens
if pol_mode == '2pol':
pol0_gains = {k: v for k, v in gains.items() if k[1] == antpols[0]}
pol1_gains = {k: v for k, v in gains.items() if k[1] == antpols[1]}
reds0 = [gp for gp in reds if gp[0][-1] in join_pol(antpols[0], antpols[0])]
reds1 = [gp for gp in reds if gp[0][-1] in join_pol(antpols[1], antpols[1])]
new_gains = remove_degen_gains(reds0, pol0_gains, degen_gains=degen_gains, mode=mode, pol_mode='1pol')
new_gains.update(remove_degen_gains(reds1, pol1_gains, degen_gains=degen_gains, mode=mode, pol_mode='1pol'))
return new_gains
# Extract gains and degenerate gains and put into numpy arrays
gainSols = np.array([gains[ant] for ant in ants])
if degen_gains is None:
if mode == 'phase':
degenGains = np.array([np.zeros_like(gains[ant]) for ant in ants])
else: # complex
degenGains = np.array([np.ones_like(gains[ant]) for ant in ants])
else:
degenGains = np.array([degen_gains[ant] for ant in ants])
# Build matrices for projecting gain degeneracies
antpos = reds_to_antpos(reds)
positions = np.array([antpos[ant[0]] for ant in ants])
if pol_mode == '1pol' or pol_mode == '4pol_minV':
# In 1pol and 4pol_minV, the phase degeneracies are 1 overall phase and 2 tip-tilt terms
# Rgains maps gain phases to degenerate parameters (either average phases or phase slopes)
Rgains = np.hstack((positions, np.ones((positions.shape[0], 1))))
else: # pol_mode is '4pol'
# two columns give sums for two different polarizations
phasePols = np.vstack((gainPols == antpols[0], gainPols == antpols[1])).T
Rgains = np.hstack((positions, phasePols))
# Mgains is like (AtA)^-1 At in linear estimator formalism. It's a normalized estimator of degeneracies
Mgains = np.linalg.pinv(Rgains.T.dot(Rgains)).dot(Rgains.T)
# degenToRemove is the amount we need to move in the degenerate subspace
if mode == 'phase':
# Fix phase terms only
degenToRemove = np.einsum('ij,jkl', Mgains, gainSols - degenGains)
gainSols -= np.einsum('ij,jkl', Rgains, degenToRemove)
else: # working on complex data
# Fix phase terms
degenToRemove = np.einsum('ij,jkl', Mgains, np.angle(gainSols * np.conj(degenGains)))
gainSols *= np.exp(np.complex64(-1j) * np.einsum('ij,jkl', Rgains, degenToRemove))
# Fix abs terms: fixes the mean abs product of gains (as they appear in visibilities)
for pol in antpols:
meanSqAmplitude = np.mean([np.abs(g1 * g2) for (a1, p1), g1 in gains.items()
for (a2, p2), g2 in gains.items()
if p1 == pol and p2 == pol and a1 != a2], axis=0)
degenMeanSqAmplitude = np.mean([(np.ones_like(gains[k1]) if degen_gains is None
else np.abs(degen_gains[k1] * degen_gains[k2]))
for k1 in gains.keys() for k2 in gains.keys()
if k1[1] == pol and k2[1] == pol and k1[0] != k2[0]], axis=0)
gainSols[gainPols == pol] *= (degenMeanSqAmplitude / meanSqAmplitude)**.5
# Create new solutions dictionary
new_gains = {ant: gainSol for ant, gainSol in zip(ants, gainSols)}
return new_gains
class RedSol():
'''Object for containing solutions to redundant calibration, namely gains and
unique-baseline visibilities, along with a variety of convenience methods.'''
def __init__(self, reds, gains=None, vis=None, sol_dict=None):
'''Initializes RedSol object.
Arguments:
reds: list of lists of redundant baseline tuples, e.g. (0, 1, 'ee')
gains: optional dictionary. Maps keys like (1, 'Jee') to complex
numpy arrays of gains of size (Ntimes, Nfreqs).
vis: optional dictionary or DataContainer. Maps keys like (0, 1, 'ee')
to complex numpy arrays of visibilities of size (Ntimes, Nfreqs).
May only contain at most one visibility per unique baseline group.
sol_dict: optional dictionary. Maps both gain keys and visibility keys
to numpy arrays. Must be empty if gains or vis is not.
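Example (an illustrative sketch; the stored visibility solution is shared by all
baselines in its redundant group):
>>> reds = [[(0, 1, 'ee'), (1, 2, 'ee')]]
>>> gains = {(a, 'Jee'): np.ones((2, 3), dtype=complex) for a in range(3)}
>>> vis = {(0, 1, 'ee'): np.full((2, 3), 2. + 0.j)}
>>> sol = RedSol(reds, gains=gains, vis=vis)
>>> np.allclose(sol.model_bl((1, 2, 'ee')), 2.)
True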
'''
# None defaults (rather than mutable {} defaults) prevent instances from accidentally sharing dicts
gains = {} if gains is None else gains
vis = {} if vis is None else vis
sol_dict = {} if sol_dict is None else sol_dict
if len(sol_dict) > 0:
if (len(gains) > 0) or (len(vis) > 0):
raise ValueError('If sol_dict is not empty, both gains and vis must be.')
self.gains = {key: val for key, val in sol_dict.items() if len(key) == 2}
vis = {key: val for key, val in sol_dict.items() if len(key) == 3}
else:
self.gains = gains
self.reds = reds
self.vis = RedDataContainer(vis, reds=self.reds)
def __getitem__(self, key):
'''Get underlying gain or visibility, depending on the length of the key.'''
if len(key) == 3: # visibility key
return self.vis[key]
elif len(key) == 2: # antenna-pol key
return self.gains[key]
else:
raise KeyError('RedSol keys should be length-2 (for gains) or length-3 (for visibilities).')
def __setitem__(self, key, value):
'''Set underlying gain or visibility, depending on the length of the key.'''
if len(key) == 3: # visibility key
self.vis[key] = value
elif len(key) == 2: # antenna-pol key
self.gains[key] = value
else:
raise KeyError('RedSol keys should be length-2 (for gains) or length-3 (for visibilities).')
def __contains__(self, key):
'''Returns True if key is a gain key or a redundant visibility key, False otherwise.'''
return (key in self.gains) or (key in self.vis)
def __iter__(self):
'''Iterate over gain keys, then iterate over visibility keys.'''
return chain(self.gains, self.vis)
def __len__(self):
'''Returns the total number of entries in self.gains or self.vis.'''
return len(self.gains) + len(self.vis)
def keys(self):
'''Iterate over gain keys, then iterate over visibility keys.'''
return self.__iter__()
def values(self):
'''Iterate over gain values, then iterate over visibility values.'''
return chain(self.gains.values(), self.vis.values())
def items(self):
'''Iterate over (key, value) pairs of the gains, then over those of the visibilities.'''
return chain(self.gains.items(), self.vis.items())
def get(self, key, default=None):
'''Returns value associated with key, but default if key is not found.'''
if key in self:
return self[key]
else:
return default
def make_sol_finite(self):
'''Replaces nans and infs in this object, see redcal.make_sol_finite() for details.'''
make_sol_finite(self)
def remove_degen(self, degen_sol=None, inplace=True):
""" Removes degeneracies from solutions (or replaces them with those in degen_sol).
Arguments:
degen_sol: Optional dictionary or RedSol containing gain solutions in the
{(index,antpol): np.array} format. Gain amplitudes and phases in degen_sol
replace those of this RedSol in the degenerate subspace of redcal. If
left as None, average gain amplitudes will be 1 and average phase terms will be 0.
Visibilities in degen_sol are ignored, so this can also be a dictionary of gains.
inplace: If True, replaces self.vis and self.gains. If False, returns a new RedSol object.
Returns:
new_sol: if not inplace, RedSol with degeneracy removal/replacement performed
"""
old_gains = self.gains
new_gains = remove_degen_gains(self.reds, old_gains, degen_gains=degen_sol, mode='complex',
pol_mode=parse_pol_mode(self.reds))
if inplace:
calibrate_in_place(self.vis, new_gains, old_gains=old_gains)
self.gains = new_gains
else:
new_vis = deepcopy(self.vis)
calibrate_in_place(new_vis, new_gains, old_gains=old_gains)
return RedSol(self.reds, gains=new_gains, vis=new_vis)
def gain_bl(self, bl):
'''Return gain for baseline bl = (ai, aj).
Arguments:
bl: tuple, baseline to be split into antennas indexing gain.
Returns:
gain: gi * conj(gj)
'''
ai, aj = split_bl(bl)
return self.gains[ai] * np.conj(self.gains[aj])
def model_bl(self, bl):
'''Return visibility data model (gain * vissol) for baseline bl
Arguments:
bl: tuple, baseline to return model for
Returns:
vis: gi * conj(gj) * vis[bl]
'''
return self.gain_bl(bl) * self.vis[bl]
def calibrate_bl(self, bl, data, copy=True):
'''Return calibrated data for baseline bl
Arguments:
bl: tuple, baseline from which to divide out gains
data: numpy array of data to calibrate
copy: if False, apply calibration to data in place
Returns:
vis: data / (gi * conj(gj))
'''
gij = self.gain_bl(bl)
if copy:
return np.divide(data, gij, where=(gij != 0))
else:
np.divide(data, gij, out=data, where=(gij != 0))
return data
def update_vis_from_data(self, data, wgts=None, reds_to_update=None):
'''Performs redundant averaging of data using reds and gains stored in this RedSol object and
stores the result as the redundant solution.
Arguments:
data: DataContainer containing visibilities to redundantly average.
wgts: optional DataContainer weighting visibilities in averaging.
If not provided, it is assumed that all data are uniformly weighted.
If provided, must include all keys in reds_to_update (or self.reds).
If the weights sum to 0 for any time/freq in any redundant group, all baselines
in that group that are not flagged for all times and freqs are weighted equally.
reds_to_update: list of reds to update, otherwise update all.
Returns:
None
'''
if reds_to_update is None:
reds_to_update = self.reds
else:
self.vis.build_red_keys(combine_reds(self.reds, reds_to_update))
self.reds = self.vis.reds
for grp in reds_to_update:
wgts_here = ([wgts[bl] for bl in grp] if (wgts is not None) else None)
if wgts_here is not None:
if np.all([np.all(wgt == 0) for wgt in wgts_here]):
# If the entire group has 0 weight, exclude this group from averaging
continue
if np.any(np.sum(wgts_here, axis=0) == 0):
# If any time/freq is completely flagged, perform uniform averaging over not-completely flagged baselines
not_totally_flagged_bls = [bl for bl, wgt in zip(grp, wgts_here) if not np.all(wgt == 0)]
flag_waterfall = np.all([wgts[bl] == 0 for bl in not_totally_flagged_bls], axis=0)
wgts_here = [np.where(flag_waterfall, 1, wgts[bl]) if bl in not_totally_flagged_bls else wgts[bl] for bl in grp]
self.vis[grp[0]] = np.average([self.calibrate_bl(bl, data[bl]) for bl in grp], axis=0, weights=wgts_here)
def extend_vis(self, data, wgts=None, reds_to_solve=None):
'''Performs redundant averaging of ubls not already solved for in RedSol.vis
and adds them to RedSol.vis
Arguments:
data: DataContainer containing visibilities to redundantly average.
wgts: optional DataContainer weighting visibilities in averaging.
If not provided, it is assumed that all data are uniformly weighted.
If provided, must include all keys in reds_to_solve (or self.reds).
If the weights sum to 0 for any time/freq in any redundant group, all baselines
in that group that are not flagged for all times and freqs are weighted equally.
reds_to_solve: subset of reds to update, otherwise update all
Returns:
None
'''
if reds_to_solve is None:
unsolved_reds = [gp for gp in self.reds if not gp[0] in self.vis]
reds_to_solve = filter_reds(unsolved_reds, ants=self.gains.keys())
self.update_vis_from_data(data, wgts=wgts, reds_to_update=reds_to_solve)
def extend_gains(self, data, wgts={}, extended_reds=None):
'''Extend redundant solutions to antenna gains not already solved for
using redundant baseline solutions in RedSol.vis, adding them to RedSol.gains.
Arguments:
data: DataContainer containing visibilities to redundantly average.
wgts: optional DataContainer weighting visibilities in averaging.
If not provided, it is assumed that all data are uniformly weighted.
extended_reds: Broader list of reds to update, otherwise use existing reds.
Returns:
None
'''
if extended_reds is None:
extended_reds = self.reds
gsum = {}
gwgt = {}
for grp in extended_reds:
try:
u = self.vis[grp[0]] # RedDataContainer will take care of mapping.
except KeyError:
# no redundant visibility solution for this group, so skip
continue
# loop through baselines and select ones that have one solved antenna
# and one unsolved to solve for.
for bl in grp:
a_i, a_j = split_bl(bl)
if a_i not in self.gains:
if a_j not in self.gains:
# no solution for either antenna in this baseline, so skip
continue
_gsum = data[bl] * (u.conj() * self[a_j])
_gwgt = np.abs(u)**2 * np.abs(self[a_j])**2
if len(wgts) > 0:
_gsum *= wgts[bl]
_gwgt *= wgts[bl]
gsum[a_i] = gsum.get(a_i, 0) + _gsum
gwgt[a_i] = gwgt.get(a_i, 0) + _gwgt
elif a_j not in self.gains:
_gsum = data[bl].conj() * (u * self[a_i])
_gwgt = np.abs(u)**2 * np.abs(self[a_i])**2
if len(wgts) > 0:
_gsum *= wgts[bl]
_gwgt *= wgts[bl]
gsum[a_j] = gsum.get(a_j, 0) + _gsum
gwgt[a_j] = gwgt.get(a_j, 0) + _gwgt
for k in gsum.keys():
self[k] = np.divide(gsum[k], gwgt[k], where=(gwgt[k] > 0))
def chisq(self, data, data_wgts, gain_flags=None):
"""Computes chi^2 defined as: chi^2 = sum_ij(|data_ij - model_ij * g_i conj(g_j)|^2 * wgts_ij)
and also a chisq_per_antenna which is the same sum but with fixed i.
Arguments:
data: DataContainer mapping baseline-pol tuples like (0,1,'nn') to complex data of shape (Nt, Nf).
data_wgts: multiplicative weights with which to combine chisq per visibility. Usually
equal to (visibility noise variance)**-1.
gain_flags: optional dictionary mapping ant-pol keys like (1,'Jnn') to a boolean flags waterfall
with the same shape as the data. Default: None, which means no per-antenna flagging.
Returns:
chisq: numpy array of chi^2 calculated as above, with the same shape as each visibility. If the
inferred pol_mode from reds (see redcal.parse_pol_mode) is '1pol' or '2pol', this is a
dictionary mapping antenna polarization (e.g. 'Jnn') to chi^2. Otherwise, there is a single
chisq (because polarizations mix) and this is a numpy array.
chisq_per_ant: dictionary mapping ant-pol keys like (1,'Jnn') to chisq per antenna, computed as
above but keeping i fixed and varying only j.
"""
split_by_antpol = parse_pol_mode(self.reds) in ['1pol', '2pol']
chisq, _, chisq_per_ant, _ = utils.chisq(data, self.vis, data_wgts=data_wgts,
gains=self.gains, gain_flags=gain_flags,
reds=self.reds, split_by_antpol=split_by_antpol)
return chisq, chisq_per_ant
def normalized_chisq(self, data, data_wgts):
'''Computes chi^2 and chi^2 per antenna with proper normalization per DoF.
Arguments:
data: DataContainer mapping baseline-pol tuples like (0,1,'nn') to complex data of shape (Nt, Nf).
data_wgts: multiplicative weights with which to combine chisq per visibility. Usually
equal to (visibility noise variance)**-1.
Returns:
chisq: chi^2 per degree of freedom for the calibration solution. If the inferred pol_mode from
reds (see redcal.parse_pol_mode) is '1pol' or '2pol', this is a dictionary mapping antenna
polarization (e.g. 'Jnn') to chi^2. Otherwise, there is a single chisq (because polarizations
mix) and this is a numpy array.
chisq_per_ant: dictionary mapping ant-pol tuples like (1,'Jnn') to the sum of all chisqs for
visibilities that an antenna participates in, DoF normalized using predict_chisq_per_ant
'''
chisq, chisq_per_ant = normalized_chisq(data, data_wgts, self.reds, self.vis, self.gains)
return chisq, chisq_per_ant
def _check_polLists_minV(polLists):
"""Given a list of unique visibility polarizations (e.g. for each red group), returns whether
they are all either single identical polarizations (e.g. 'nn') or both cross polarizations
(e.g. ['ne','en']) so that the '4pol_minV' mode can be assumed."""
for polList in polLists:
if len(polList) == 1:
if split_pol(polList[0])[0] != split_pol(polList[0])[1]:
return False
elif len(polList) == 2:
if polList[0] != conj_pol(polList[1]) or split_pol(polList[0])[0] == split_pol(polList[0])[1]:
return False
else:
return False
return True
def parse_pol_mode(reds):
"""Based on reds, figures out the pol_mode.
Args:
reds: list of list of baselines (with polarizations) considered redundant
Returns:
pol_mode: polarization mode of calibration
'1pol': 1 antpol and 1 vispol (e.g. 'Jnn' and 'nn'). Default.
'2pol': 2 antpols, no cross-vispols (e.g. 'Jnn','Jee' and 'nn','ee')
'4pol': 2 antpols, 4 vispols (e.g. 'Jnn','Jee' and 'nn','ne','en','ee')
'4pol_minV': 2 antpols, 4 vispols in data but assuming V_ne = V_en in model
'unrecognized_pol_mode': something else
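Example (an illustrative sketch):
>>> parse_pol_mode([[(0, 1, 'nn'), (1, 2, 'nn')]])
'1pol'
>>> parse_pol_mode([[(0, 1, 'nn')], [(0, 1, 'ee')]])
'2pol'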
"""
pols = list(set([bl[2] for bls in reds for bl in bls]))
antpols = list(set([antpol for pol in pols for antpol in split_pol(pol)]))
if len(pols) == 1 and len(antpols) == 1:
return '1pol'
elif len(pols) == 2 and np.all([split_pol(pol)[0] == split_pol(pol)[1] for pol in pols]):
return '2pol'
elif len(pols) == 4 and len(antpols) == 2:
polLists = [list(set([bl[2] for bl in bls])) for bls in reds]
polListLens = np.array([len(polList) for polList in polLists])
if np.all(polListLens == 1) and len(pols) == 4 and len(antpols) == 2:
return '4pol'
elif _check_polLists_minV(polLists) and len(pols) == 4 and len(antpols) == 2:
return '4pol_minV'
else:
return 'unrecognized_pol_mode'
else:
return 'unrecognized_pol_mode'
class OmnicalSolver(linsolve.LinProductSolver):
def __init__(self, data, sol0, wgts={}, gain=.3, **kwargs):
"""Set up a nonlinear system of equations of the form g_i * g_j.conj() * V_mdl = V_ij
to linearize via the Omnical algorithm described in HERA Memo 50
(scripts/notebook/omnical_convergence.ipynb).
Args:
data: Dictionary that maps nonlinear product equations, written as valid python-interpretable
strings that include the variables in question, to (complex) numbers or numpy arrays.
Variables with trailing underscores '_' are interpreted as complex conjugates (e.g. x*y_
parses as x * y.conj()).
sol0: Dictionary mapping all variables (as keyword strings) to their starting guess values.
This is the point that is Taylor expanded around, so it must be relatively close to the
true chi^2 minimizing solution. In the same format as that produced by
linsolve.LogProductSolver.solve() or linsolve.LinProductSolver.solve().
wgts: Dictionary that maps equation strings from data to real weights to apply to each
equation. Weights are treated as 1/sigma^2. All equations in the data must have a weight
if wgts is not the default, {}, which means all 1.0s.
gain: The fractional step made toward the new solution each iteration. Default is 0.3.
Values in the range 0.1 to 0.5 are generally safe. Increasing values trade speed
for stability.
**kwargs: keyword arguments of constants (python variables in keys of data that
are not to be solved for) which are passed to linsolve.LinProductSolver.
"""
linsolve.LinProductSolver.__init__(self, data, sol0, wgts=wgts, **kwargs)
self.gain = np.float32(gain) # float32 to avoid accidentally promoting data to doubles.
def _get_ans0(self, sol, keys=None):
'''Evaluate the system of equations given input sol.
Specify keys to evaluate only a subset of the equations.'''
if keys is None:
keys = self.keys
_sol = {k + '_': v.conj() for k, v in sol.items() if k.startswith('g')}
_sol.update(sol)
return {k: eval(k, _sol) for k in keys}
def solve_iteratively(self, conv_crit=1e-10, maxiter=50, check_every=4, check_after=1,
wgt_func=lambda x: 1., verbose=False):
"""Repeatedly solves and updates solution until convergence or maxiter is reached.
Returns metadata about the solution and the solution itself.
Args:
conv_crit: A convergence criterion (default 1e-10) below which to stop iterating.
Convergence is measured as the L2-norm of the change in the solution of all the variables
divided by the L2-norm of the solution itself.
maxiter: An integer maximum number of iterations to perform before quitting. Default 50.
check_every: Compute convergence and updates weights every Nth iteration (saves computation). Default 4.
check_after: Start computing convergence and updating weights after the first N iterations. Default 1.
wgt_func: a function f(abs^2 * wgt) operating on weighted absolute differences between
data and model that returns an additional data weighting to apply when calculating
chisq and updating parameters. Example: lambda x: np.where(x>0, 5*np.tanh(x/5)/x, 1)
clamps deviations to 5 sigma. Default is no additional weighting (lambda x: 1.).
Returns: meta, sol
meta: a dictionary with metadata about the solution, including
iter: the number of iterations taken to reach convergence (or maxiter), with dimensions of the data.
chisq: the chi^2 of the solution produced by the final iteration, with dimensions of the data.
conv_crit: the convergence criterion evaluated at the final iteration, with dimensions of the data.
sol: a dictionary of complex solutions with variables as keys, with dimensions of the data.
"""
sol = self.sol0
terms = [(linsolve.get_name(gi), linsolve.get_name(gj), linsolve.get_name(uij))
for term in self.all_terms for (gi, gj, uij) in term]
dmdl_u = self._get_ans0(sol)
abs2_u = {k: np.abs(self.data[k] - dmdl_u[k])**2 * self.wgts[k] for k in self.keys}
chisq = sum([v * wgt_func(v) for v in abs2_u.values()])
update = np.where(chisq > 0)
abs2_u = {k: v[update] for k, v in abs2_u.items()}
# variables with '_u' are flattened and only include pixels that need updating
dmdl_u = {k: v[update].flatten() for k, v in dmdl_u.items()}
# wgts_u hold the wgts the user provides
wgts_u = {k: (v * np.ones(chisq.shape, dtype=np.float32))[update].flatten()
for k, v in self.wgts.items()}
# clamp_wgts_u adds additional sigma clamping done by wgt_func.
# abs2_u holds abs(data - mdl)**2 * wgt (i.e. noise-weighted deviations), which is
# passed to wgt_func to determine any additional weighting (to, e.g., clamp outliers).
clamp_wgts_u = {k: v * wgt_func(abs2_u[k]) for k, v in wgts_u.items()}
sol_u = {k: v[update].flatten() for k, v in sol.items()}
iters = np.zeros(chisq.shape, dtype=int)
conv = np.ones_like(chisq)
for i in range(1, maxiter + 1):
if verbose:
print('Beginning iteration %d/%d' % (i, maxiter))
if (i % check_every) == 1:
# compute data wgts: dwgts = sum(V_mdl^2 / n^2) = sum(V_mdl^2 * wgts)
# don't need to update data weighting with every iteration
# clamped weighting is passed to dwgts_u, which is used to update parameters
dwgts_u = {k: dmdl_u[k] * dmdl_u[k].conj() * clamp_wgts_u[k] for k in self.keys}
sol_wgt_u = {k: 0 for k in sol.keys()}
for k, (gi, gj, uij) in zip(self.keys, terms):
w = dwgts_u[k]
sol_wgt_u[gi] += w
sol_wgt_u[gj] += w
sol_wgt_u[uij] += w
dw_u = {k: v[update] * dwgts_u[k] for k, v in self.data.items()}
sol_sum_u = {k: 0 for k in sol_u.keys()}
for k, (gi, gj, uij) in zip(self.keys, terms):
# compute sum(wgts * V_meas / V_mdl)
numerator = dw_u[k] / dmdl_u[k]
sol_sum_u[gi] += numerator
sol_sum_u[gj] += numerator.conj()
sol_sum_u[uij] += numerator
new_sol_u = {k: v * ((1 - self.gain) + self.gain * sol_sum_u[k] / sol_wgt_u[k])
for k, v in sol_u.items()}
dmdl_u = self._get_ans0(new_sol_u)
# check if i % check_every is 0, which is purposely one less than the '1' up at the top of the loop
if i < maxiter and (i < check_after or (i % check_every) != 0):
# Fast branch when we aren't expensively computing convergence/chisq
sol_u = new_sol_u
else:
# Slow branch when we compute convergence/chisq
abs2_u = {k: np.abs(v[update] - dmdl_u[k])**2 * wgts_u[k] for k, v in self.data.items()}
new_chisq_u = sum([v * wgt_func(v) for v in abs2_u.values()])
chisq_u = chisq[update]
gotbetter_u = (chisq_u > new_chisq_u)
where_gotbetter_u = np.where(gotbetter_u)
update_where = tuple(u[where_gotbetter_u] for u in update)
chisq[update_where] = new_chisq_u[where_gotbetter_u]
iters[update_where] = i
new_sol_u = {k: np.where(gotbetter_u, v, sol_u[k]) for k, v in new_sol_u.items()}
deltas_u = [v - sol_u[k] for k, v in new_sol_u.items()]
conv_u = np.sqrt(sum([(v * v.conj()).real for v in deltas_u])
/ sum([(v * v.conj()).real for v in new_sol_u.values()]))
conv[update_where] = conv_u[where_gotbetter_u]
for k, v in new_sol_u.items():
sol[k][update] = v
update_u = np.where((conv_u > conv_crit) & gotbetter_u)
if update_u[0].size == 0 or i == maxiter:
meta = {'iter': iters, 'chisq': chisq, 'conv_crit': conv}
return meta, sol
dmdl_u = {k: v[update_u] for k, v in dmdl_u.items()}
wgts_u = {k: v[update_u] for k, v in wgts_u.items()}
sol_u = {k: v[update_u] for k, v in new_sol_u.items()}
abs2_u = {k: v[update_u] for k, v in abs2_u.items()}
clamp_wgts_u = {k: v * wgt_func(abs2_u[k]) for k, v in wgts_u.items()}
update = tuple(u[update_u] for u in update)
if verbose:
print(' <CHISQ> = %f, <CONV> = %f, CNT = %d' % (np.mean(chisq), np.mean(conv), update[0].size))
def _wrap_phs(phs, wrap_pnt=(np.pi / 2)):
'''Adjust phase wrap point to be [-wrap_pnt, 2pi-wrap_pnt)'''
return (phs + wrap_pnt) % (2 * np.pi) - wrap_pnt
def _flip_frac(offsets, flipped=set(), flip_pnt=(np.pi / 2)):
'''Calculate the fraction of (bl1, bl2) pairings each antenna is involved
in that have large phase offsets.'''
cnt = {}
tot = {}
for (bl1, bl2), off in offsets.items():
ijmn = split_bl(bl1) + split_bl(bl2)
num_in_flipped = sum([int(ant in flipped) for ant in ijmn])
for ant in ijmn:
tot[ant] = tot.get(ant, 0) + 1
if off > flip_pnt and num_in_flipped % 2 == 0:
cnt[ant] = cnt.get(ant, 0) + 1
flip_frac = [(k, v / tot[k]) for k, v in cnt.items()]
return flip_frac
def _find_flipped(offsets, flip_pnt=(np.pi / 2), maxiter=100):
'''Given a dict of (bl1, bl2) keys and phase offset vals, identify
antennas which are likely to have a np.pi phase offset.'''
flipped = set()
for i in range(maxiter):
flip_frac = _flip_frac(offsets, flipped=flipped, flip_pnt=flip_pnt)
changed = False
for (ant, frac) in flip_frac:
if frac > 0.5:
changed = True
if ant in flipped:
flipped.remove(ant)
else:
flipped.add(ant)
if not changed:
break
return flipped
def _firstcal_align_bls(bls, freqs, data, norm=True, wrap_pnt=(np.pi / 2)):
'''Given a redundant group of bls, find per-baseline dly/off params that
bring them into phase alignment using hierarchical pairing.'''
fftfreqs = np.fft.fftfreq(freqs.shape[-1], np.median(np.diff(freqs)))
dtau = fftfreqs[1] - fftfreqs[0]
grps = [(bl,) for bl in bls] # start with each bl in its own group
_data = {bl: data[bl[0]] for bl in grps}
Ntimes, Nfreqs = data[bls[0]].shape
times = np.arange(Ntimes)
dly_off_gps = {}